Skip to content

Commit

Permalink
Fixed a bug where num_layers was ignored when invariants_only=False
Browse files Browse the repository at this point in the history
  • Loading branch information
RokasEl committed Nov 18, 2024
1 parent fbc62fa commit f545df3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
18 changes: 12 additions & 6 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,24 +390,30 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
num_interactions = int(self.models[0].num_interactions)
if num_layers == -1:
num_layers = int(self.models[0].num_interactions)
num_layers = num_interactions
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]

irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = num_invariant_features # Equivariant features not created for the last layer

if invariants_only:
irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
l_max = irreps_out.lmax
num_features = irreps_out.dim // (l_max + 1) ** 2
descriptors = [
extract_invariant(
descriptor,
num_layers=num_layers,
num_features=num_features,
num_features=num_invariant_features,
l_max=l_max,
)
for descriptor in descriptors
]
descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors]
to_keep = np.sum(per_layer_features[:num_layers])
descriptors = [descriptor[:,:to_keep].detach().cpu().numpy() for descriptor in descriptors]

if self.num_models == 1:
return descriptors[0]
Expand Down
20 changes: 14 additions & 6 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,25 +474,33 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):

desc_invariant = calc.get_descriptors(at, invariants_only=True)
desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True)
desc_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1)
desc_single_layer_rotated = calc.get_descriptors(
desc_invariant_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1)
desc_invariant_single_layer_rotated = calc.get_descriptors(
at_rotated, invariants_only=True, num_layers=1
)
desc = calc.get_descriptors(at, invariants_only=False)
desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1)
desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False)
desc_rotated_single_layer = calc.get_descriptors(at_rotated, invariants_only=False, num_layers=1)

assert desc_invariant.shape[0] == 3
assert desc_invariant.shape[1] == 32
assert desc_single_layer.shape[0] == 3
assert desc_single_layer.shape[1] == 16
assert desc_invariant_single_layer.shape[0] == 3
assert desc_invariant_single_layer.shape[1] == 16
assert desc.shape[0] == 3
assert desc.shape[1] == 80
assert desc_single_layer.shape[0] == 3
assert desc_single_layer.shape[1] == 16*4
assert desc_rotated_single_layer.shape[0] == 3
assert desc_rotated_single_layer.shape[1] == 16*4

np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
np.testing.assert_allclose(desc_single_layer, desc_invariant[:, :16], atol=1e-6)
np.testing.assert_allclose(desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6)
np.testing.assert_allclose(
desc_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
)
np.testing.assert_allclose(desc_single_layer[:,:16], desc_rotated_single_layer[:,:16], atol=1e-6)
assert not np.allclose(desc_single_layer[:,16:], desc_rotated_single_layer[:,16:], atol=1e-6)
assert not np.allclose(desc, desc_rotated, atol=1e-6)


Expand Down

0 comments on commit f545df3

Please sign in to comment.