Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
RokasEl committed Nov 18, 2024
1 parent f545df3 commit a6a729a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
10 changes: 7 additions & 3 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,14 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
num_layers = num_interactions
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]

irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = num_invariant_features # Equivariant features not created for the last layer
per_layer_features[-1] = (
num_invariant_features # Equivariant features not created for the last layer
)

if invariants_only:
descriptors = [
Expand All @@ -413,7 +415,9 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
for descriptor in descriptors
]
to_keep = np.sum(per_layer_features[:num_layers])
descriptors = [descriptor[:,:to_keep].detach().cpu().numpy() for descriptor in descriptors]
descriptors = [
descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
]

if self.num_models == 1:
return descriptors[0]
Expand Down
24 changes: 17 additions & 7 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,18 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):

desc_invariant = calc.get_descriptors(at, invariants_only=True)
desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True)
desc_invariant_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1)
desc_invariant_single_layer = calc.get_descriptors(
at, invariants_only=True, num_layers=1
)
desc_invariant_single_layer_rotated = calc.get_descriptors(
at_rotated, invariants_only=True, num_layers=1
)
desc = calc.get_descriptors(at, invariants_only=False)
desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1)
desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False)
desc_rotated_single_layer = calc.get_descriptors(at_rotated, invariants_only=False, num_layers=1)
desc_rotated_single_layer = calc.get_descriptors(
at_rotated, invariants_only=False, num_layers=1
)

assert desc_invariant.shape[0] == 3
assert desc_invariant.shape[1] == 32
Expand All @@ -490,17 +494,23 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
assert desc.shape[0] == 3
assert desc.shape[1] == 80
assert desc_single_layer.shape[0] == 3
assert desc_single_layer.shape[1] == 16*4
assert desc_single_layer.shape[1] == 16 * 4
assert desc_rotated_single_layer.shape[0] == 3
assert desc_rotated_single_layer.shape[1] == 16*4
assert desc_rotated_single_layer.shape[1] == 16 * 4

np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
np.testing.assert_allclose(desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6)
np.testing.assert_allclose(
desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6
)
np.testing.assert_allclose(
desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
)
np.testing.assert_allclose(desc_single_layer[:,:16], desc_rotated_single_layer[:,:16], atol=1e-6)
assert not np.allclose(desc_single_layer[:,16:], desc_rotated_single_layer[:,16:], atol=1e-6)
np.testing.assert_allclose(
desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6
)
assert not np.allclose(
desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6
)
assert not np.allclose(desc, desc_rotated, atol=1e-6)


Expand Down

0 comments on commit a6a729a

Please sign in to comment.