From f545df3201142e446287387de89ddd8d07b27c14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rokas=20Elijo=C5=A1ius?= Date: Mon, 18 Nov 2024 16:07:41 +0000 Subject: [PATCH] fixed bug where num_layers was ignored when invariants_only=False --- mace/calculators/mace.py | 18 ++++++++++++------ tests/test_calculator.py | 20 ++++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index dcd2b8e5..d456852e 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -390,24 +390,30 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): atoms = self.atoms if self.model_type != "MACE": raise NotImplementedError("Only implemented for MACE models") + num_interactions = int(self.models[0].num_interactions) if num_layers == -1: - num_layers = int(self.models[0].num_interactions) + num_layers = num_interactions batch = self._atoms_to_batch(atoms) descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] + + irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] + l_max = irreps_out.lmax + num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 + per_layer_features = [irreps_out.dim for _ in range(num_interactions)] + per_layer_features[-1] = num_invariant_features # Equivariant features not created for the last layer + if invariants_only: - irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] - l_max = irreps_out.lmax - num_features = irreps_out.dim // (l_max + 1) ** 2 descriptors = [ extract_invariant( descriptor, num_layers=num_layers, - num_features=num_features, + num_features=num_invariant_features, l_max=l_max, ) for descriptor in descriptors ] - descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors] + to_keep = np.sum(per_layer_features[:num_layers]) + descriptors = [descriptor[:,:to_keep].detach().cpu().numpy() for descriptor in descriptors] if self.num_models == 1: return descriptors[0] diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 97db4da3..1796f908 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -474,25 +474,33 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model): desc_invariant = calc.get_descriptors(at, invariants_only=True) desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) - desc_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1) - desc_single_layer_rotated = calc.get_descriptors( + desc_invariant_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1) + desc_invariant_single_layer_rotated = calc.get_descriptors( at_rotated, invariants_only=True, num_layers=1 ) desc = calc.get_descriptors(at, invariants_only=False) + desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) + desc_rotated_single_layer = calc.get_descriptors(at_rotated, invariants_only=False, num_layers=1) assert desc_invariant.shape[0] == 3 assert desc_invariant.shape[1] == 32 - assert desc_single_layer.shape[0] == 3 - assert desc_single_layer.shape[1] == 16 + assert desc_invariant_single_layer.shape[0] == 3 + assert desc_invariant_single_layer.shape[1] == 16 assert desc.shape[0] == 3 assert desc.shape[1] == 80 + assert desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16*4 + assert desc_rotated_single_layer.shape[0] == 3 + assert desc_rotated_single_layer.shape[1] == 16*4 np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) - np.testing.assert_allclose(desc_single_layer, desc_invariant[:, :16], atol=1e-6) + np.testing.assert_allclose(desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6) np.testing.assert_allclose( - desc_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 ) + np.testing.assert_allclose(desc_single_layer[:,:16], desc_rotated_single_layer[:,:16], atol=1e-6) + assert not np.allclose(desc_single_layer[:,16:], desc_rotated_single_layer[:,16:], atol=1e-6) assert not np.allclose(desc, desc_rotated, atol=1e-6)