diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index d456852e..46d9884e 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -395,12 +395,14 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): num_layers = num_interactions batch = self._atoms_to_batch(atoms) descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] - + irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] l_max = irreps_out.lmax num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 per_layer_features = [irreps_out.dim for _ in range(num_interactions)] - per_layer_features[-1] = num_invariant_features # Equivariant features not created for the last layer + per_layer_features[-1] = ( + num_invariant_features # Equivariant features not created for the last layer + ) if invariants_only: descriptors = [ @@ -413,7 +415,9 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): for descriptor in descriptors ] to_keep = np.sum(per_layer_features[:num_layers]) - descriptors = [descriptor[:,:to_keep].detach().cpu().numpy() for descriptor in descriptors] + descriptors = [ + descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors + ] if self.num_models == 1: return descriptors[0] diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 1796f908..5f3cce41 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -474,14 +474,18 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model): desc_invariant = calc.get_descriptors(at, invariants_only=True) desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) - desc_invariant_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1) + desc_invariant_single_layer = calc.get_descriptors( + at, invariants_only=True, num_layers=1 + ) desc_invariant_single_layer_rotated = calc.get_descriptors( at_rotated, invariants_only=True, num_layers=1 ) desc = calc.get_descriptors(at, invariants_only=False) desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) - desc_rotated_single_layer = calc.get_descriptors(at_rotated, invariants_only=False, num_layers=1) + desc_rotated_single_layer = calc.get_descriptors( + at_rotated, invariants_only=False, num_layers=1 + ) assert desc_invariant.shape[0] == 3 assert desc_invariant.shape[1] == 32 @@ -490,17 +494,23 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model): assert desc.shape[0] == 3 assert desc.shape[1] == 80 assert desc_single_layer.shape[0] == 3 - assert desc_single_layer.shape[1] == 16*4 + assert desc_single_layer.shape[1] == 16 * 4 assert desc_rotated_single_layer.shape[0] == 3 - assert desc_rotated_single_layer.shape[1] == 16*4 + assert desc_rotated_single_layer.shape[1] == 16 * 4 np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) - np.testing.assert_allclose(desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6) + np.testing.assert_allclose( + desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 + ) np.testing.assert_allclose( desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 ) - np.testing.assert_allclose(desc_single_layer[:,:16], desc_rotated_single_layer[:,:16], atol=1e-6) - assert not np.allclose(desc_single_layer[:,16:], desc_rotated_single_layer[:,16:], atol=1e-6) + np.testing.assert_allclose( + desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 + ) + assert not np.allclose( + desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 + ) assert not np.allclose(desc, desc_rotated, atol=1e-6)