From f1ef56e22078b2935764c25e0d35709b4aa80596 Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Sun, 26 Mar 2023 19:53:24 +0100 Subject: [PATCH 1/7] e0 calc --- mace/modules/models.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/mace/modules/models.py b/mace/modules/models.py index 736cf1cb..27a5c5ce 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -151,6 +151,25 @@ def __init__( else: self.readouts.append(LinearReadoutBlock(hidden_irreps)) + def _forward_calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: + """Forward pass internal - e0 calculation + + Notes + ----- + This is really a lookup from the e0 list we have for each node's element. + + Parameters + ---------- + data + + Returns + ------- + e0 + shape: [n_nodes, ] + + """ + return self.atomic_energies_fn(data["node_attrs"]) + def forward( self, data: Dict[str, torch.Tensor], @@ -184,7 +203,7 @@ def forward( ) # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) + node_e0 = self._forward_calculate_e0(data) e0 = scatter_sum( src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs ) # [n_graphs,] @@ -276,8 +295,8 @@ def forward( compute_displacement: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: # Setup - data["positions"].requires_grad_(True) data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 displacement = torch.zeros( (num_graphs, 3, 3), From 3cdac231b9d681b9f2eb422cdca62d8b645e692c Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Sun, 26 Mar 2023 21:07:38 +0100 Subject: [PATCH 2/7] refactor of MACE so that ScaleShiftMACE is making use of inheritance - extracted the layer calculations - "ScaleShift" applied to non-duplicated layer calculation code --- mace/modules/models.py | 181 +++++++++++------------------------------ 1 file changed, 47 insertions(+), 134 deletions(-) diff --git a/mace/modules/models.py 
b/mace/modules/models.py index 27a5c5ce..008a0498 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -13,7 +13,6 @@ from mace.data import AtomicData from mace.tools.scatter import scatter_sum - from .blocks import ( AtomicEnergiesBlock, EquivariantProductBasisBlock, @@ -170,44 +169,22 @@ def _forward_calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: """ return self.atomic_energies_fn(data["node_attrs"]) - def forward( + def _forward_calculate_interactions( self, data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) + ) -> torch.Tensor: + """Forward pass internal - interaction energy calculation - # Atomic energies - node_e0 = self._forward_calculate_e0(data) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + Parameters + ---------- + data + Returns + ------- + interaction_energy + shape: [n_nodes, ] + + """ # Embeddings node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( @@ -219,8 +196,7 @@ def forward( edge_feats = self.radial_embedding(lengths) # Interactions - energies = [e0] - node_energies_list = [node_e0] + node_interaction_energies = [] for interaction, product, readout in 
zip( self.interactions, self.products, self.readouts ): @@ -236,54 +212,11 @@ def forward( sc=sc, node_attrs=data["node_attrs"], ) - node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - energies.append(energy) - node_energies_list.append(node_energies) - - # Sum over energy contributions - contributions = torch.stack(energies, dim=-1) - print("contributions", contributions.shape) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - node_energy_contributions = torch.stack(node_energies_list, dim=-1) - node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] - # Outputs - forces, virials, stress = get_outputs( - energy=total_energy, - positions=data["positions"], - displacement=displacement, - cell=data["cell"], - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - ) - - return { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "virials": virials, - "stress": stress, - "displacement": displacement, - } + node_interaction_energies.append( + readout(node_feats).squeeze(-1) # [n_nodes, ] + ) - -@compile_mode("script") -class ScaleShiftMACE(MACE): - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) + return torch.sum(torch.stack(node_interaction_energies, dim=0), dim=0) def forward( self, @@ -317,56 +250,19 @@ def forward( batch=data["batch"], ) - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = 
get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) - - # Interactions - node_es_list = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] - ) - node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } - - # Sum over interactions - node_inter_es = torch.sum( - torch.stack(node_es_list, dim=0), dim=0 - ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) + # Calculate energy contributions + node_e0 = self._forward_calculate_e0(data) + interaction_energies = self._forward_calculate_interactions(data) - # Sum over nodes in graph - inter_e = scatter_sum( - src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs + # Sum over energy contributions + local_energies = node_e0 + interaction_energies + total_energy = scatter_sum( + src=local_energies, index=data["batch"], dim=-1, dim_size=num_graphs ) # [n_graphs,] - # Add E_0 and (scaled) interaction energy - total_energy = e0 + inter_e - node_energy = node_e0 + node_inter_es - + # Calculate derivatives if needed forces, virials, stress = get_outputs( - energy=inter_e, + energy=total_energy, positions=data["positions"], displacement=displacement, cell=data["cell"], @@ -376,17 +272,34 @@ def forward( compute_stress=compute_stress, ) - output = { + return { "energy": total_energy, - "node_energy": node_energy, - "interaction_energy": inter_e, + "local_energies": local_energies, "forces": forces, "virials": virials, "stress": stress, "displacement": displacement, } - return output + +@compile_mode("script") +class 
ScaleShiftMACE(MACE): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def _forward_calculate_interactions( + self, + data: Dict[str, torch.Tensor], + ) -> torch.Tensor: + return self.scale_shift(super()._forward_calculate_interactions(data)) class BOTNet(torch.nn.Module): From dcf8020d8752c67c0611e6ebf194b13116d7d217 Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Sat, 1 Apr 2023 08:20:57 +0100 Subject: [PATCH 3/7] docstrings and variable names --- mace/modules/models.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/mace/modules/models.py b/mace/modules/models.py index 008a0498..04311703 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -36,6 +36,8 @@ @compile_mode("script") class MACE(torch.nn.Module): + """MACE model""" + def __init__( self, r_max: float, @@ -227,6 +229,33 @@ def forward( compute_stress: bool = False, compute_displacement: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: + """Forward pass - evaluation of model + + Parameters + ---------- + data + input graphs, as dictionary + training + If we are in training mode. Computational graphs are retained, for reuse. + compute_force + Compute forces on all nodes + compute_virials, compute_stress + Compute stress and virial, any of these two triggers calculation of both + compute_displacement + Compute symmetric displacements + + + Returns + ------- + results + calculated results + - `energy`: total energy on each graph + - `node_energy`: energy of each node, i.e. 
local energy on each atom + - `forces`: negative gradient of total energy on each atom (node) + - `virials` & `stress`: same thing expressed differently + - `displacement`: symmetric displacement of cell used for stress + + """ # Setup data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) @@ -255,9 +284,9 @@ def forward( interaction_energies = self._forward_calculate_interactions(data) # Sum over energy contributions - local_energies = node_e0 + interaction_energies + node_energy = node_e0 + interaction_energies total_energy = scatter_sum( - src=local_energies, index=data["batch"], dim=-1, dim_size=num_graphs + src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs ) # [n_graphs,] # Calculate derivatives if needed @@ -274,7 +303,7 @@ def forward( return { "energy": total_energy, - "local_energies": local_energies, + "node_energy": node_energy, "forces": forces, "virials": virials, "stress": stress, @@ -284,6 +313,11 @@ def forward( @compile_mode("script") class ScaleShiftMACE(MACE): + """MACE Scaled and Shifted + + Same as MACE model, but allows for constant shift and rescaling. 
+ """ + def __init__( self, atomic_inter_scale: float, From 8cb55795840d00868754a0e5eb9c8a922da9c103 Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Sun, 2 Apr 2023 09:40:55 +0100 Subject: [PATCH 4/7] Unified MACE, DipolesMACE, & EnergyDipolesMACE - DipolesMACE & EnergyDipolesMACE inherit from MACE - EnergyDipolesMACE accepting dict for forward pass + closer match to MACE class - DipoleOnly versions of blocks added explicitly (to be refactored) - JIT for DipolesMACE & EnergyDipolesMACE --- mace/modules/blocks.py | 41 +++- mace/modules/models.py | 468 ++++++++++++++--------------------------- tests/test_models.py | 2 +- 3 files changed, 193 insertions(+), 318 deletions(-) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 382831b9..fa26c0f6 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -66,12 +66,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [... @compile_mode("script") class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False): + def __init__(self, irreps_in: o3.Irreps): super().__init__() - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@compile_mode("script") +class LinearDipoleOnlyReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps): + super().__init__() + self.irreps_out = o3.Irreps("1x1o") self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] @@ -100,24 +108,38 @@ def __init__( [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] ) irreps_gates = o3.Irreps([mul, "0e"] for mul, _ 
in irreps_gated) - self.equivariant_nonlin = nn.Gate( + self.non_linearity = nn.Gate( irreps_scalars=irreps_scalars, act_scalars=[gate for _, ir in irreps_scalars], irreps_gates=irreps_gates, act_gates=[gate] * len(irreps_gates), irreps_gated=irreps_gated, ) - self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() + self.irreps_nonlin = self.non_linearity.irreps_in.simplify() self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin) self.linear_2 = o3.Linear( irreps_in=self.hidden_irreps, irreps_out=self.irreps_out ) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - x = self.equivariant_nonlin(self.linear_1(x)) + x = self.non_linearity(self.linear_1(x)) return self.linear_2(x) # [n_nodes, 1] +@compile_mode("script") +class NonLinearDipoleOnlyReadoutBlock(NonLinearDipoleReadoutBlock): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Callable, + ): + # fixme: use more reasonable inheritance + super().__init__( + irreps_in=irreps_in, MLP_irreps=MLP_irreps, gate=gate, dipole_only=True + ) + + @compile_mode("script") class AtomicEnergiesBlock(torch.nn.Module): atomic_energies: torch.Tensor @@ -258,7 +280,8 @@ def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int): def forward( self, - sender_or_receiver_node_attrs: torch.Tensor, # assumes that the node attributes are one-hot encoded + sender_or_receiver_node_attrs: torch.Tensor, + # assumes that the node attributes are one-hot encoded edge_feats: torch.Tensor, ): return torch.einsum( diff --git a/mace/modules/models.py b/mace/modules/models.py index 04311703..f2ad909e 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -18,9 +18,11 @@ EquivariantProductBasisBlock, InteractionBlock, LinearDipoleReadoutBlock, + LinearDipoleOnlyReadoutBlock, LinearNodeEmbeddingBlock, LinearReadoutBlock, NonLinearDipoleReadoutBlock, + NonLinearDipoleOnlyReadoutBlock, NonLinearReadoutBlock, 
RadialEmbeddingBlock, ScaleShiftBlock, @@ -38,6 +40,9 @@ class MACE(torch.nn.Module): """MACE model""" + _LINEAR_READOUT_CLASS = LinearReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock + def __init__( self, r_max: float, @@ -117,13 +122,11 @@ def __init__( self.products = torch.nn.ModuleList([prod]) self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) for i in range(num_interactions - 1): if i == num_interactions - 2: - hidden_irreps_out = str( - hidden_irreps[0] - ) # Select only scalars for last layer + hidden_irreps_out = self._last_layer_irreps(hidden_irreps) else: hidden_irreps_out = hidden_irreps inter = interaction_cls( @@ -147,10 +150,19 @@ def __init__( self.products.append(prod) if i == num_interactions - 2: self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) + self._NONLINEAR_READOUT_CLASS(hidden_irreps_out, MLP_irreps, gate) ) else: - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) + + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """Irreps to use in the last layer - used for initialisation of subclasses + + energy: Select only scalars for last layer + + """ + return o3.Irreps(str(hidden_irreps[0])) def _forward_calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: """Forward pass internal - e0 calculation @@ -198,7 +210,7 @@ def _forward_calculate_interactions( edge_feats = self.radial_embedding(lengths) # Interactions - node_interaction_energies = [] + layer_outputs = [] for interaction, product, readout in zip( self.interactions, self.products, self.readouts ): @@ -214,11 +226,10 @@ def _forward_calculate_interactions( sc=sc, node_attrs=data["node_attrs"], ) - node_interaction_energies.append( - readout(node_feats).squeeze(-1) # [n_nodes, ] - ) - return 
torch.sum(torch.stack(node_interaction_energies, dim=0), dim=0) + layer_outputs.append(readout(node_feats).squeeze(-1)) # [n_nodes, ] + + return torch.sum(torch.stack(layer_outputs, dim=0), dim=0) def forward( self, @@ -414,15 +425,15 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: data.positions.requires_grad = True # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"]) e0 = scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] ) # [n_graphs,] # Embeddings - node_feats = self.node_embedding(data.node_attrs) + node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) @@ -431,15 +442,18 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: energies = [e0] for interaction, readout in zip(self.interactions, self.readouts): node_feats = interaction( - node_attrs=data.node_attrs, + node_attrs=data["node_attrs"], node_feats=node_feats, edge_attrs=edge_attrs, edge_feats=edge_feats, - edge_index=data.edge_index, + edge_index=data["edge_index"], ) node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] energy = scatter_sum( - src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_energies, + index=data["batch"], + dim=-1, + dim_size=data["num_graphs"], ) # [n_graphs,] energies.append(energy) @@ -475,15 +489,15 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: data.positions.requires_grad = True # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"]) e0 = 
scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] ) # [n_graphs,] # Embeddings - node_feats = self.node_embedding(data.node_attrs) + node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) @@ -492,11 +506,11 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: node_es_list = [] for interaction, readout in zip(self.interactions, self.readouts): node_feats = interaction( - node_attrs=data.node_attrs, + node_attrs=data["node_attrs"], node_feats=node_feats, edge_attrs=edge_attrs, edge_feats=edge_feats, - edge_index=data.edge_index, + edge_index=data["edge_index"], ) node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } @@ -509,7 +523,7 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: # Sum over nodes in graph inter_e = scatter_sum( - src=node_inter_es, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_inter_es, index=data["batch"], dim=-1, dim_size=data["num_graphs"] ) # [n_graphs,] # Add E_0 and (scaled) interaction energy @@ -525,7 +539,13 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: return output -class AtomicDipolesMACE(torch.nn.Module): +@compile_mode("script") +class AtomicDipolesMACE(MACE): + """MACE model for Dipoles only""" + + _LINEAR_READOUT_CLASS = LinearDipoleOnlyReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearDipoleOnlyReadoutBlock + def __init__( self, r_max: float, @@ -545,114 +565,59 @@ def __init__( atomic_energies: Optional[ None ], # Just here to make it compatible with energy models, MUST be None + radial_MLP: Optional[List[int]] = None, ): - 
super().__init__() - assert atomic_energies is None - self.r_max = r_max - self.atomic_numbers = atomic_numbers + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( + super().__init__( r_max=r_max, num_bessel=num_bessel, num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - - # Interactions and readouts - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, + max_ell=max_ell, + interaction_cls=interaction_cls, + interaction_cls_first=interaction_cls_first, + num_interactions=num_interactions, + num_elements=num_elements, hidden_irreps=hidden_irreps, + MLP_irreps=MLP_irreps, + atomic_energies=np.array([]), avg_num_neighbors=avg_num_neighbors, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, + atomic_numbers=atomic_numbers, correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, + 
gate=gate, + radial_MLP=radial_MLP, ) - self.products = torch.nn.ModuleList([prod]) - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """Irreps to use in the last layer - used for initialisation of subclasses - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[1] - ) # Select only l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=inter.irreps_out, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=True - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) - ) + dipole: Select only l=1 vectors for last layer + + """ + return o3.Irreps(str(hidden_irreps[1])) def forward( self, - data: AtomicData, - training=False, + data: Dict[str, torch.Tensor], + training: bool = False, compute_force: bool = False, compute_virials: bool = False, compute_stress: bool = False, - ) -> Dict[str, Any]: - assert compute_force is False - assert compute_virials is False - assert compute_stress is False + compute_displacement: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # dipoles and virials / stress not supported 
simultaneously + error_msg = "AtomicDipolesMACE does not support energy & its derivatives" + assert not compute_force, error_msg + assert not compute_virials, error_msg + assert not compute_stress, error_msg + assert not compute_displacement, error_msg + # Setup - data.positions.requires_grad = True + num_graphs = data["ptr"].numel() - 1 + data["positions"].requires_grad = True if not training: for p in self.parameters(): p.requires_grad = False @@ -660,60 +625,37 @@ def forward( for p in self.parameters(): p.requires_grad = True - # Embeddings - node_feats = self.node_embedding(data.node_attrs) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) + # Evaluate layer outputs + atomic_dipoles = self._forward_calculate_interactions(data) - # Interactions - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data.node_attrs, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data.edge_index, - ) - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data.node_attrs - ) - node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] - dipoles.append(node_dipoles) - - # Compute the dipoles - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + # Sum over dipole contributions + baseline_dipole = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] total_dipole = scatter_sum( src=atomic_dipoles, - index=data.batch.unsqueeze(-1), + index=data["batch"].unsqueeze(-1), dim=0, - dim_size=data.num_graphs, + dim_size=num_graphs, ) # [n_graphs,3] - 
baseline = compute_fixed_charge_dipole( - charges=data.charges, - positions=data.positions, - batch=data.batch, - num_graphs=data.num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline + total_dipole += baseline_dipole - output = { + return { "dipole": total_dipole, "atomic_dipoles": atomic_dipoles, } - return output +@compile_mode("script") +class EnergyDipolesMACE(MACE): + """MACE model for Energy & Dipoles""" + + _LINEAR_READOUT_CLASS = LinearDipoleReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearDipoleReadoutBlock -class EnergyDipolesMACE(torch.nn.Module): def __init__( self, r_max: float, @@ -726,119 +668,63 @@ def __init__( num_elements: int, hidden_irreps: o3.Irreps, MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, avg_num_neighbors: float, atomic_numbers: List[int], correlation: int, gate: Optional[Callable], - atomic_energies: Optional[np.ndarray], + radial_MLP: Optional[List[int]] = None, ): - super().__init__() - self.r_max = r_max - self.atomic_numbers = atomic_numbers - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + + super().__init__( r_max=r_max, num_bessel=num_bessel, num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - - # Interactions and readouts - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - 
inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, + max_ell=max_ell, + interaction_cls=interaction_cls, + interaction_cls_first=interaction_cls_first, + num_interactions=num_interactions, + num_elements=num_elements, hidden_irreps=hidden_irreps, + MLP_irreps=MLP_irreps, + atomic_energies=atomic_energies, avg_num_neighbors=avg_num_neighbors, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, + atomic_numbers=atomic_numbers, correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, + gate=gate, + radial_MLP=radial_MLP, ) - self.products = torch.nn.ModuleList([prod]) - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """Irreps to use in the last layer - used for initialisation of subclasses - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[:2] - ) # Select scalars and l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=inter.irreps_out, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - prod = 
EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=False - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) - ) + energy & dipole: Select scalars and l=1 vectors for last layer + + """ + return o3.Irreps(str(hidden_irreps[:2])) def forward( self, - data: AtomicData, - training=False, + data: Dict[str, torch.Tensor], + training: bool = False, compute_force: bool = True, compute_virials: bool = False, compute_stress: bool = False, - ) -> Dict[str, Any]: + compute_displacement: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: # dipoles and virials / stress not supported simultaneously - assert compute_virials is False - assert compute_stress is False + error_msg = "dipoles and virials / stress not supported simultaneously" + assert not compute_virials, error_msg + assert not compute_stress, error_msg + assert not compute_displacement, error_msg + # Setup - data.positions.requires_grad = True + num_graphs = data["ptr"].numel() - 1 + data["positions"].requires_grad = True if not training: for p in self.parameters(): p.requires_grad = False @@ -847,84 +733,50 @@ def forward( p.requires_grad = True # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) - e0 = scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] + node_e0 = self._forward_calculate_e0(data) - # Embeddings - node_feats = self.node_embedding(data.node_attrs) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) + # Evaluate layer 
outputs + layer_outputs = self._forward_calculate_interactions(data) + interaction_energies = layer_outputs[:, 0] + atomic_dipoles = layer_outputs[:, 1:] - # Interactions - energies = [e0] - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data.node_attrs, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data.edge_index, - ) - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data.node_attrs - ) - node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] - # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - node_energies = node_out[:, 0] - energy = scatter_sum( - src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] - energies.append(energy) - # node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] - node_dipoles = node_out[:, 1:] - dipoles.append(node_dipoles) + # Sum over energy contributions + node_energy = node_e0 + interaction_energies + total_energy = scatter_sum( + src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] - # Compute the energies and dipoles - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + # Sum over dipole contributions + baseline_dipole = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] total_dipole = scatter_sum( src=atomic_dipoles, - index=data.batch.unsqueeze(-1), + index=data["batch"].unsqueeze(-1), dim=0, - dim_size=data.num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data.charges, - positions=data.positions, - batch=data.batch, - 
num_graphs=data.num_graphs, + dim_size=num_graphs, ) # [n_graphs,3] - total_dipole = total_dipole + baseline + total_dipole += baseline_dipole + # Calculate derivatives if needed forces, _, _ = get_outputs( energy=total_energy, - positions=data.positions, + positions=data["positions"], displacement=None, - cell=data.cell, + cell=data["cell"], training=training, compute_force=compute_force, compute_virials=compute_virials, compute_stress=compute_stress, ) - output = { + return { "energy": total_energy, - "contributions": contributions, + "node_energy": node_energy, "forces": forces, "dipole": total_dipole, "atomic_dipoles": atomic_dipoles, } - - return output diff --git a/tests/test_models.py b/tests/test_models.py index f39c437f..d0177e16 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -182,7 +182,7 @@ def test_energy_dipole_mace(): ) batch = next(iter(data_loader)) output = model( - batch, + batch.to_dict(), training=True, ) # sanity check of dipoles being the right shape From 34e82e0934a52001cba0658ddaa5a1e6e153cc7c Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Sun, 2 Apr 2023 12:55:56 +0100 Subject: [PATCH 5/7] cleaner mixin classes for models --- mace/modules/models.py | 649 ++++++++++++++++++++++++++++------------- 1 file changed, 441 insertions(+), 208 deletions(-) diff --git a/mace/modules/models.py b/mace/modules/models.py index f2ad909e..8def3dfa 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -4,7 +4,7 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, Tuple import numpy as np import torch @@ -37,8 +37,15 @@ @compile_mode("script") -class MACE(torch.nn.Module): - """MACE model""" +class MaceCoreModel(torch.nn.Module): + """Core model for all MACE models + + Includes the following + 
- graph parameters: r_max, elements, etc. + - embeddings (node, edge) + - readout blocks (subclasses can change settings of these) + + """ _LINEAR_READOUT_CLASS = LinearReadoutBlock _NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock @@ -55,14 +62,49 @@ def __init__( num_elements: int, hidden_irreps: o3.Irreps, MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, avg_num_neighbors: float, atomic_numbers: List[int], correlation: int, gate: Optional[Callable], radial_MLP: Optional[List[int]] = None, + **kwargs, ): - super().__init__() + """Core MACE model + + Parameters + ---------- + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + -> ??? + max_ell + l_max + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + """ + super().__init__(**kwargs) + + # Main Buffers self.register_buffer( "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) ) @@ -70,6 +112,7 @@ def __init__( self.register_buffer( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) + # Embedding node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) @@ -91,9 +134,8 @@ def __init__( ) if radial_MLP is None: radial_MLP = [64, 64, 64] - # Interactions and readout - self.atomic_energies_fn = 
AtomicEnergiesBlock(atomic_energies) + # Interactions and readout inter = interaction_cls_first( node_attrs_irreps=node_attr_irreps, node_feats_irreps=node_feats_irreps, @@ -159,35 +201,18 @@ def __init__( def _last_layer_irreps(hidden_irreps) -> o3.Irreps: """Irreps to use in the last layer - used for initialisation of subclasses - energy: Select only scalars for last layer + core model: drops highest l irreps, unless it's scalar only """ - return o3.Irreps(str(hidden_irreps[0])) - - def _forward_calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: - """Forward pass internal - e0 calculation + if len(hidden_irreps) == 1: + return hidden_irreps + return o3.Irreps(str(hidden_irreps[:-1])) - Notes - ----- - This is really a lookup from the e0 list we have for each node's element. - - Parameters - ---------- - data - - Returns - ------- - e0 - shape: [n_nodes, ] - - """ - return self.atomic_energies_fn(data["node_attrs"]) - - def _forward_calculate_interactions( + def _calculate_layer_interactions( self, data: Dict[str, torch.Tensor], ) -> torch.Tensor: - """Forward pass internal - interaction energy calculation + """Calculate the layer interactions - used within forward pass Parameters ---------- @@ -195,8 +220,8 @@ def _forward_calculate_interactions( Returns ------- - interaction_energy - shape: [n_nodes, ] + layer_outputs + shape: [n_nodes, len(ouptuts)] """ # Embeddings @@ -231,6 +256,188 @@ def _forward_calculate_interactions( return torch.sum(torch.stack(layer_outputs, dim=0), dim=0) + +@compile_mode("script") +class EnergyModelMixin(torch.nn.Module): + """Mixin class for energy models + + Supplies: + - e0: atomic energy block + + """ + + def __init__(self, atomic_energies: np.ndarray, **kwargs): + super().__init__(**kwargs) + + # Energy calculation specific bits + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + def _calc_energy( + self, + data: Dict[str, torch.Tensor], + layer_output_energies: torch.Tensor, + ) -> 
Tuple[torch.Tensor, torch.Tensor]: + num_graphs = data["ptr"].numel() - 1 + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + + # Sum over energy contributions + node_energy = node_e0 + layer_output_energies + total_energy = scatter_sum( + src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + return total_energy, node_energy + + def _calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: + """e0 calculation, element-wise energy shift + + Notes + ----- + This is really a lookup from the e0 list we have for each node's element. + + """ + return self.atomic_energies_fn(data["node_attrs"]) + + +@compile_mode("script") +class ScaleShiftEnergyModelMixin(EnergyModelMixin): + """Scaled and shifted energy model""" + + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + """ + + Parameters + ---------- + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) + + **kwargs + """ + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def _calc_energy( + self, + data: Dict[str, torch.Tensor], + layer_output_energies: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return super()._calc_energy(data, self.scale_shift(layer_output_energies)) + + +@compile_mode("script") +class DipoleModelMixin(torch.nn.Module): + """Mixin class for dipole models + + Supplies: + - function to calculate total dipole, including fixed charge baseline + + """ + + @staticmethod + def _calc_total_dipole( + data: Dict[str, torch.Tensor], + atomic_dipoles: torch.Tensor, + ) -> torch.Tensor: + """Calculates total dipoles - adding fixed charge baseline + + Parameters + ---------- + data + atomic_dipoles + corresponding output of the interaction layers + + Returns + ------- + total_dipole + shape: [n_graphs,3] + + """ + num_graphs = 
data["ptr"].numel() - 1 + + # Sum over dipole contributions + baseline_dipole = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"].unsqueeze(-1), + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + total_dipole += baseline_dipole + + return total_dipole # [n_graphs,3] + + +@compile_mode("script") +class MACE(MaceCoreModel, EnergyModelMixin): + """MACE model""" + + _LINEAR_READOUT_CLASS = LinearReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock + + def __init__(self, **kwargs): + """MACE energy model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + + **kwargs + """ + super().__init__(**kwargs) + + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """energy: Select only scalars for last layer""" + return o3.Irreps(str(hidden_irreps[0])) + def forward( self, data: 
Dict[str, torch.Tensor], @@ -255,7 +462,6 @@ def forward( compute_displacement Compute symmetric displacements - Returns ------- results @@ -290,15 +496,10 @@ def forward( batch=data["batch"], ) - # Calculate energy contributions - node_e0 = self._forward_calculate_e0(data) - interaction_energies = self._forward_calculate_interactions(data) - - # Sum over energy contributions - node_energy = node_e0 + interaction_energies - total_energy = scatter_sum( - src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + # Calculate energies + total_energy, node_energy = self._calc_energy( + data, self._calculate_layer_interactions(data) + ) # Calculate derivatives if needed forces, virials, stress = get_outputs( @@ -323,28 +524,58 @@ def forward( @compile_mode("script") -class ScaleShiftMACE(MACE): +class ScaleShiftMACE(MACE, ScaleShiftEnergyModelMixin): """MACE Scaled and Shifted Same as MACE model, but allows for constant shift and rescaling. """ - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) + def __init__(self, **kwargs): + """Scaled & Shifted MACE energy model - def _forward_calculate_interactions( - self, - data: Dict[str, torch.Tensor], - ) -> torch.Tensor: - return self.scale_shift(super()._forward_calculate_interactions(data)) + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + 
hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) + + **kwargs + """ + super().__init__(**kwargs) class BOTNet(torch.nn.Module): @@ -364,8 +595,9 @@ def __init__( gate: Optional[Callable], avg_num_neighbors: float, atomic_numbers: List[int], + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.r_max = r_max self.atomic_numbers = atomic_numbers # Embedding @@ -540,63 +772,62 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: @compile_mode("script") -class AtomicDipolesMACE(MACE): +class AtomicDipolesMACE(MaceCoreModel, DipoleModelMixin): """MACE model for Dipoles only""" _LINEAR_READOUT_CLASS = LinearDipoleOnlyReadoutBlock _NONLINEAR_READOUT_CLASS = NonLinearDipoleOnlyReadoutBlock - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[ - None - ], # Just here to make it compatible with energy models, MUST be None - radial_MLP: Optional[List[int]] = None, - ): + def __init__(self, hidden_irreps: o3.Irreps, atomic_energies=None, **kwargs): + """MACE dipole only model + + Parameters + ---------- 
+ # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + **kwargs + + Notes + ----- + Requires at least l=1 features to be used for representing dipoles. + `atomic_energies` parameter added for compatibility, but it's ignored + + """ assert ( - len(hidden_irreps) > 1 + hidden_irreps.lmax >= 1 ), "To predict dipoles use at least l=1 hidden_irreps" - - super().__init__( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - max_ell=max_ell, - interaction_cls=interaction_cls, - interaction_cls_first=interaction_cls_first, - num_interactions=num_interactions, - num_elements=num_elements, - hidden_irreps=hidden_irreps, - MLP_irreps=MLP_irreps, - atomic_energies=np.array([]), - avg_num_neighbors=avg_num_neighbors, - atomic_numbers=atomic_numbers, - correlation=correlation, - gate=gate, - radial_MLP=radial_MLP, - ) + super().__init__(hidden_irreps=hidden_irreps, **kwargs) @staticmethod def _last_layer_irreps(hidden_irreps) -> o3.Irreps: - """Irreps to use in the last layer - used for initialisation of subclasses - - dipole: Select only l=1 vectors for last layer - - """ + """dipole: Select only l=1 vectors for last 
layer""" return o3.Irreps(str(hidden_irreps[1])) def forward( @@ -608,6 +839,27 @@ def forward( compute_stress: bool = False, compute_displacement: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: + """Forward pass - evaluation of model + + Parameters + ---------- + data + input graphs, as dictionary + training + If we are in training mode. Computational graphs are retained, for reuse. + compute_force, compute_virials, compute_stress, compute_displacement + compatibility only, RAISES ERROR IF True + + Returns + ------- + results + calculated results + - `dipole`: dipole of the graph, includes atomic and fixed charge dipole + components as well + - `atomic_dipoles`: dipoles of each atom + + """ + # dipoles and virials / stress not supported simultaneously error_msg = "AtomicDipolesMACE does not support energy & its derivatives" assert not compute_force, error_msg @@ -616,95 +868,79 @@ def forward( assert not compute_displacement, error_msg # Setup - num_graphs = data["ptr"].numel() - 1 data["positions"].requires_grad = True - if not training: - for p in self.parameters(): - p.requires_grad = False - else: - for p in self.parameters(): - p.requires_grad = True # Evaluate layer outputs - atomic_dipoles = self._forward_calculate_interactions(data) - - # Sum over dipole contributions - baseline_dipole = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"].unsqueeze(-1), - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - total_dipole += baseline_dipole + atomic_dipoles = self._calculate_layer_interactions(data) return { - "dipole": total_dipole, + "dipole": self._calc_total_dipole(data, atomic_dipoles), "atomic_dipoles": atomic_dipoles, } @compile_mode("script") -class EnergyDipolesMACE(MACE): +class EnergyDipolesMACE(MACE, DipoleModelMixin): """MACE model for Energy & Dipoles""" 
_LINEAR_READOUT_CLASS = LinearDipoleReadoutBlock _NONLINEAR_READOUT_CLASS = NonLinearDipoleReadoutBlock - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - radial_MLP: Optional[List[int]] = None, - ): + def __init__(self, hidden_irreps: o3.Irreps, **kwargs): + """MACE energy & dipole model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + + **kwargs + + Notes + ----- + Requires at least l=1 features to be used for representing dipoles. 
+ `atomic_energies` parameter added for compatibility, but it's ignored + + """ assert ( - len(hidden_irreps) > 1 + hidden_irreps.lmax >= 1 ), "To predict dipoles use at least l=1 hidden_irreps" - - super().__init__( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - max_ell=max_ell, - interaction_cls=interaction_cls, - interaction_cls_first=interaction_cls_first, - num_interactions=num_interactions, - num_elements=num_elements, - hidden_irreps=hidden_irreps, - MLP_irreps=MLP_irreps, - atomic_energies=atomic_energies, - avg_num_neighbors=avg_num_neighbors, - atomic_numbers=atomic_numbers, - correlation=correlation, - gate=gate, - radial_MLP=radial_MLP, - ) + super().__init__(hidden_irreps=hidden_irreps, **kwargs) @staticmethod def _last_layer_irreps(hidden_irreps) -> o3.Irreps: - """Irreps to use in the last layer - used for initialisation of subclasses - - energy & dipole: Select scalars and l=1 vectors for last layer - - """ + """energy & dipole: Select scalars and l=1 vectors for last layer""" return o3.Irreps(str(hidden_irreps[:2])) def forward( @@ -716,6 +952,31 @@ def forward( compute_stress: bool = False, compute_displacement: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: + """Forward pass - evaluation of model + + Parameters + ---------- + data + input graphs, as dictionary + training + If we are in training mode. Computational graphs are retained, for reuse. + compute_force + Compute forces on all nodes + compute_virials, compute_stress, compute_displacement + compatibility only, RAISES ERROR IF True + + Returns + ------- + results + calculated results + - `energy`: total energy on each graph + - `node_energy`: energy of each node, i.e. 
local energy on each atom + - `forces`: negative gradient of total energy on each atom (node) + - `dipole`: dipole of the graph, includes atomic and fixed charge dipole + components as well + - `atomic_dipoles`: dipoles of each atom + + """ # dipoles and virials / stress not supported simultaneously error_msg = "dipoles and virials / stress not supported simultaneously" assert not compute_virials, error_msg @@ -723,43 +984,15 @@ def forward( assert not compute_displacement, error_msg # Setup - num_graphs = data["ptr"].numel() - 1 data["positions"].requires_grad = True - if not training: - for p in self.parameters(): - p.requires_grad = False - else: - for p in self.parameters(): - p.requires_grad = True - - # Atomic energies - node_e0 = self._forward_calculate_e0(data) - # Evaluate layer outputs - layer_outputs = self._forward_calculate_interactions(data) + # Evaluate layer outputs & unpack + layer_outputs = self._calculate_layer_interactions(data) interaction_energies = layer_outputs[:, 0] atomic_dipoles = layer_outputs[:, 1:] - # Sum over energy contributions - node_energy = node_e0 + interaction_energies - total_energy = scatter_sum( - src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Sum over dipole contributions - baseline_dipole = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"].unsqueeze(-1), - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - total_dipole += baseline_dipole + # Calculate energies + total_energy, node_energy = self._calc_energy(data, interaction_energies) # Calculate derivatives if needed forces, _, _ = get_outputs( @@ -777,6 +1010,6 @@ def forward( "energy": total_energy, "node_energy": node_energy, "forces": forces, - "dipole": total_dipole, + "dipole": self._calc_total_dipole(data, atomic_dipoles), "atomic_dipoles": 
atomic_dipoles, } From 5cf332737a4ebffbc32e1d683d55583902d130e3 Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Sun, 2 Apr 2023 15:13:30 +0100 Subject: [PATCH 6/7] refactor the non-high level model classes into a new file --- mace/modules/core_models.py | 378 ++++++++++++++++ mace/modules/models.py | 851 +++++++++++------------------------- 2 files changed, 635 insertions(+), 594 deletions(-) create mode 100644 mace/modules/core_models.py diff --git a/mace/modules/core_models.py b/mace/modules/core_models.py new file mode 100644 index 00000000..61b9a42a --- /dev/null +++ b/mace/modules/core_models.py @@ -0,0 +1,378 @@ +"""Core model classes, and mixins for various quantities of interest + +Notes +----- +`MaceCoreModel` is the backbone of MACE models, stores all needed blocks and allows +for customisable readout shapes. + +Mixin classes are defined for calculating quantities: +- EnergyModelMixin: inherit from if you want to calculate energies +- ScaleShiftEnergyModelMixin: energies with scaling and shifting +- DipoleModelMixin: dipole learning + +""" +from typing import Type, List, Optional, Callable, Dict, Tuple + +import numpy as np +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.modules import ( + LinearReadoutBlock, + NonLinearReadoutBlock, + InteractionBlock, + LinearNodeEmbeddingBlock, + RadialEmbeddingBlock, + EquivariantProductBasisBlock, + AtomicEnergiesBlock, + ScaleShiftBlock, +) +from mace.modules.utils import get_edge_vectors_and_lengths, compute_fixed_charge_dipole +from mace.tools.scatter import scatter_sum + + +@compile_mode("script") +class MaceCoreModel(torch.nn.Module): + """Core model for all MACE models + + Includes the following + - graph parameters: r_max, elements, etc. 
+ - embeddings (node, edge) + - readout blocks (subclasses can change settings of these) + + """ + + _LINEAR_READOUT_CLASS = LinearReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock + + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + radial_MLP: Optional[List[int]] = None, + **kwargs, + ): + """Core MACE model + + Parameters + ---------- + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + """ + super().__init__(**kwargs) + + # Main Buffers + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + 
self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + + # Interactions and readout + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = self._last_layer_irreps(hidden_irreps) + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + 
target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + self._NONLINEAR_READOUT_CLASS(hidden_irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) + + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """Irreps to use in the last layer - used for initialisation of subclasses + + core model: drops highest l irreps, unless it's scalar only + + """ + if len(hidden_irreps) == 1: + return hidden_irreps + return o3.Irreps(str(hidden_irreps[:-1])) + + def _calculate_layer_interactions( + self, + data: Dict[str, torch.Tensor], + ) -> torch.Tensor: + """Calculate the layer interactions - used within forward pass + + Parameters + ---------- + data + + Returns + ------- + layer_outputs + shape: [n_nodes, len(ouptuts)] + + """ + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + # Interactions + layer_outputs = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + + layer_outputs.append(readout(node_feats).squeeze(-1)) # [n_nodes, ] + + return 
torch.sum(torch.stack(layer_outputs, dim=0), dim=0) + + +@compile_mode("script") +class EnergyModelMixin(torch.nn.Module): + """Mixin class for energy models + + Supplies: + - e0: atomic energy block + + """ + + def __init__(self, atomic_energies: np.ndarray, **kwargs): + super().__init__(**kwargs) + + # Energy calculation specific bits + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + def _calc_energy( + self, + data: Dict[str, torch.Tensor], + layer_output_energies: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_graphs = data["ptr"].numel() - 1 + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + + # Sum over energy contributions + node_energy = node_e0 + layer_output_energies + total_energy = scatter_sum( + src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + return total_energy, node_energy + + def _calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: + """e0 calculation, element-wise energy shift + + Notes + ----- + This is really a lookup from the e0 list we have for each node's element. 
+ + """ + return self.atomic_energies_fn(data["node_attrs"]) + + +@compile_mode("script") +class ScaleShiftEnergyModelMixin(EnergyModelMixin): + """Scaled and shifted energy model""" + + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + """ + + Parameters + ---------- + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) + + **kwargs + """ + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def _calc_energy( + self, + data: Dict[str, torch.Tensor], + layer_output_energies: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return super()._calc_energy(data, self.scale_shift(layer_output_energies)) + + +@compile_mode("script") +class DipoleModelMixin(torch.nn.Module): + """Mixin class for dipole models + + Supplies: + - function to calculate total dipole, including fixed charge baseline + + """ + + @staticmethod + def _calc_total_dipole( + data: Dict[str, torch.Tensor], + atomic_dipoles: torch.Tensor, + ) -> torch.Tensor: + """Calculates total dipoles - adding fixed charge baseline + + Parameters + ---------- + data + atomic_dipoles + corresponding output of the interaction layers + + Returns + ------- + total_dipole + shape: [n_graphs,3] + + """ + num_graphs = data["ptr"].numel() - 1 + + # Sum over dipole contributions + baseline_dipole = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"].unsqueeze(-1), + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + total_dipole += baseline_dipole + + return total_dipole # [n_graphs,3] diff --git a/mace/modules/models.py b/mace/modules/models.py index 8def3dfa..23a544d8 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -1,10 +1,16 @@ 
-########################################################################################### +######################################################################################## # Implementation of MACE models and other models based E(3)-Equivariant MPNNs # Authors: Ilyes Batatia, Gregor Simm # This program is distributed under the MIT License (see MIT.md) -########################################################################################### +######################################################################################## +"""Models -from typing import Any, Callable, Dict, List, Optional, Type, Tuple +These are the high-level models to be exposed. Internal workings are kept in the class +hierarchy. + +""" + +from typing import Any, Callable, Dict, List, Optional, Type import numpy as np import torch @@ -15,7 +21,6 @@ from mace.tools.scatter import scatter_sum from .blocks import ( AtomicEnergiesBlock, - EquivariantProductBasisBlock, InteractionBlock, LinearDipoleReadoutBlock, LinearDipoleOnlyReadoutBlock, @@ -27,8 +32,13 @@ RadialEmbeddingBlock, ScaleShiftBlock, ) +from .core_models import ( + MaceCoreModel, + EnergyModelMixin, + ScaleShiftEnergyModelMixin, + DipoleModelMixin, +) from .utils import ( - compute_fixed_charge_dipole, compute_forces, get_edge_vectors_and_lengths, get_outputs, @@ -36,356 +46,9 @@ ) -@compile_mode("script") -class MaceCoreModel(torch.nn.Module): - """Core model for all MACE models - - Includes the following - - graph parameters: r_max, elements, etc. 
- - embeddings (node, edge) - - readout blocks (subclasses can change settings of these) - - """ - - _LINEAR_READOUT_CLASS = LinearReadoutBlock - _NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock - - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - radial_MLP: Optional[List[int]] = None, - **kwargs, - ): - """Core MACE model - - Parameters - ---------- - r_max - cutoff of radial embedding applied to individual atoms - num_bessel - number of Bessel functions to be used for radial embedding - num_polynomial_cutoff - -> ??? - max_ell - l_max - interaction_cls - class to be used for interactions blocks - interaction_cls_first - class to be used for the first layer's interaction block - num_interactions - number of interaction layers to use - num_elements - redundant parameter for the number of elements - hidden_irreps - hidden irreducible representations, basically the size of the layer features - and hence direct control on the size of the model - MLP_irreps - avg_num_neighbors - atomic_numbers - atomic numbers of elements the model supports - order & length to agree with atomic_energies - correlation - gate - non-linearity for non-linear readouts - radial_MLP - """ - super().__init__(**kwargs) - - # Main Buffers - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 
1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - - # Interactions and readout - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - hidden_irreps_out = self._last_layer_irreps(hidden_irreps) - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - 
target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - self._NONLINEAR_READOUT_CLASS(hidden_irreps_out, MLP_irreps, gate) - ) - else: - self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) - - @staticmethod - def _last_layer_irreps(hidden_irreps) -> o3.Irreps: - """Irreps to use in the last layer - used for initialisation of subclasses - - core model: drops highest l irreps, unless it's scalar only - - """ - if len(hidden_irreps) == 1: - return hidden_irreps - return o3.Irreps(str(hidden_irreps[:-1])) - - def _calculate_layer_interactions( - self, - data: Dict[str, torch.Tensor], - ) -> torch.Tensor: - """Calculate the layer interactions - used within forward pass - - Parameters - ---------- - data - - Returns - ------- - layer_outputs - shape: [n_nodes, len(ouptuts)] - - """ - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) - - # Interactions - layer_outputs = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - - layer_outputs.append(readout(node_feats).squeeze(-1)) # [n_nodes, ] - - return 
torch.sum(torch.stack(layer_outputs, dim=0), dim=0) - - -@compile_mode("script") -class EnergyModelMixin(torch.nn.Module): - """Mixin class for energy models - - Supplies: - - e0: atomic energy block - - """ - - def __init__(self, atomic_energies: np.ndarray, **kwargs): - super().__init__(**kwargs) - - # Energy calculation specific bits - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - def _calc_energy( - self, - data: Dict[str, torch.Tensor], - layer_output_energies: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - num_graphs = data["ptr"].numel() - 1 - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - - # Sum over energy contributions - node_energy = node_e0 + layer_output_energies - total_energy = scatter_sum( - src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - return total_energy, node_energy - - def _calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: - """e0 calculation, element-wise energy shift - - Notes - ----- - This is really a lookup from the e0 list we have for each node's element. 
- - """ - return self.atomic_energies_fn(data["node_attrs"]) - - -@compile_mode("script") -class ScaleShiftEnergyModelMixin(EnergyModelMixin): - """Scaled and shifted energy model""" - - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - """ - - Parameters - ---------- - atomic_inter_scale - scale of interaction energy - atomic_inter_shift - constant shift of interaction energy (per atom) - - **kwargs - """ - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) - - def _calc_energy( - self, - data: Dict[str, torch.Tensor], - layer_output_energies: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - return super()._calc_energy(data, self.scale_shift(layer_output_energies)) - - -@compile_mode("script") -class DipoleModelMixin(torch.nn.Module): - """Mixin class for dipole models - - Supplies: - - function to calculate total dipole, including fixed charge baseline - - """ - - @staticmethod - def _calc_total_dipole( - data: Dict[str, torch.Tensor], - atomic_dipoles: torch.Tensor, - ) -> torch.Tensor: - """Calculates total dipoles - adding fixed charge baseline - - Parameters - ---------- - data - atomic_dipoles - corresponding output of the interaction layers - - Returns - ------- - total_dipole - shape: [n_graphs,3] - - """ - num_graphs = data["ptr"].numel() - 1 - - # Sum over dipole contributions - baseline_dipole = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"].unsqueeze(-1), - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - total_dipole += baseline_dipole - - return total_dipole # [n_graphs,3] - - @compile_mode("script") class MACE(MaceCoreModel, EnergyModelMixin): - """MACE model""" + """MACE model - energy only""" _LINEAR_READOUT_CLASS = LinearReadoutBlock 
_NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock @@ -524,256 +187,63 @@ def forward( @compile_mode("script") -class ScaleShiftMACE(MACE, ScaleShiftEnergyModelMixin): - """MACE Scaled and Shifted - - Same as MACE model, but allows for constant shift and rescaling. - """ - - def __init__(self, **kwargs): - """Scaled & Shifted MACE energy model - - Parameters - ---------- - # ------- core model parameters - r_max - cutoff of radial embedding applied to individual atoms - num_bessel - number of Bessel functions to be used for radial embedding - num_polynomial_cutoff - max_ell - l_max, maximum channels for spherical harmonics - interaction_cls - class to be used for interactions blocks - interaction_cls_first - class to be used for the first layer's interaction block - num_interactions - number of interaction layers to use - num_elements - redundant parameter for the number of elements - hidden_irreps - hidden irreducible representations, basically the size of the layer features - and hence direct control on the size of the model - MLP_irreps - avg_num_neighbors - atomic_numbers - atomic numbers of elements the model supports - order & length to agree with atomic_energies - correlation - gate - non-linearity for non-linear readouts - radial_MLP - - # ------- energy model specific parameters - atomic_energies - energy shift (e0) of elements - order & length to agree with atomic_numbers - atomic_inter_scale - scale of interaction energy - atomic_inter_shift - constant shift of interaction energy (per atom) - - **kwargs - """ - super().__init__(**kwargs) - - -class BOTNet(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - gate: Optional[Callable], - avg_num_neighbors: float, - 
atomic_numbers: List[int], - **kwargs, - ): - super().__init__(**kwargs) - self.r_max = r_max - self.atomic_numbers = atomic_numbers - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - - # Interactions and readouts - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - self.interactions = torch.nn.ModuleList() - self.readouts = torch.nn.ModuleList() - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - self.readouts.append(LinearReadoutBlock(inter.irreps_out)) - - for i in range(num_interactions - 1): - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=inter.irreps_out, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock(inter.irreps_out, MLP_irreps, gate) - ) - else: - self.readouts.append(LinearReadoutBlock(inter.irreps_out)) - - def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: - # Setup - data.positions.requires_grad = True - - # Atomic energies - node_e0 = 
self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) - - # Interactions - energies = [e0] - for interaction, readout in zip(self.interactions, self.readouts): - node_feats = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - energy = scatter_sum( - src=node_energies, - index=data["batch"], - dim=-1, - dim_size=data["num_graphs"], - ) # [n_graphs,] - energies.append(energy) - - # Sum over energy contributions - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - - output = { - "energy": total_energy, - "contributions": contributions, - "forces": compute_forces( - energy=total_energy, positions=data.positions, training=training - ), - } - - return output - - -class ScaleShiftBOTNet(BOTNet): - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) - - def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: - # Setup - data.positions.requires_grad = True - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - 
positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) - - # Interactions - node_es_list = [] - for interaction, readout in zip(self.interactions, self.readouts): - node_feats = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - - node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } +class ScaleShiftMACE(MACE, ScaleShiftEnergyModelMixin): + """Scaled and Shifted MACE energy model - # Sum over interactions - node_inter_es = torch.sum( - torch.stack(node_es_list, dim=0), dim=0 - ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) + Same as MACE model, but allows for constant shift and rescaling of energy + """ - # Sum over nodes in graph - inter_e = scatter_sum( - src=node_inter_es, index=data["batch"], dim=-1, dim_size=data["num_graphs"] - ) # [n_graphs,] + def __init__(self, **kwargs): + """Scaled & Shifted MACE energy model - # Add E_0 and (scaled) interaction energy - total_e = e0 + inter_e + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with 
atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP - output = { - "energy": total_e, - "forces": compute_forces( - energy=inter_e, positions=data.positions, training=training - ), - } + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) - return output + **kwargs + """ + super().__init__(**kwargs) @compile_mode("script") class AtomicDipolesMACE(MaceCoreModel, DipoleModelMixin): - """MACE model for Dipoles only""" + """MACE model for dipoles only""" _LINEAR_READOUT_CLASS = LinearDipoleOnlyReadoutBlock _NONLINEAR_READOUT_CLASS = NonLinearDipoleOnlyReadoutBlock @@ -1013,3 +483,196 @@ def forward( "dipole": self._calc_total_dipole(data, atomic_dipoles), "atomic_dipoles": atomic_dipoles, } + + +class BOTNet(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + gate: Optional[Callable], + avg_num_neighbors: float, + atomic_numbers: List[int], + **kwargs, + ): + super().__init__(**kwargs) + self.r_max = r_max + self.atomic_numbers = atomic_numbers + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = 
o3.Irreps.spherical_harmonics(max_ell) + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + + # Interactions and readouts + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + self.interactions = torch.nn.ModuleList() + self.readouts = torch.nn.ModuleList() + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions.append(inter) + self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + + for i in range(num_interactions - 1): + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=inter.irreps_out, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions.append(inter) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock(inter.irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + + def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: + # Setup + data.positions.requires_grad = True + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + # Interactions + energies = [e0] + for interaction, readout in zip(self.interactions, self.readouts): + node_feats = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, 
+ edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + energy = scatter_sum( + src=node_energies, + index=data["batch"], + dim=-1, + dim_size=data["num_graphs"], + ) # [n_graphs,] + energies.append(energy) + + # Sum over energy contributions + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + + output = { + "energy": total_energy, + "contributions": contributions, + "forces": compute_forces( + energy=total_energy, positions=data.positions, training=training + ), + } + + return output + + +class ScaleShiftBOTNet(BOTNet): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: + # Setup + data.positions.requires_grad = True + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + # Interactions + node_es_list = [] + for interaction, readout in zip(self.interactions, self.readouts): + node_feats = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + + node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + + # Sum over interactions + node_inter_es = torch.sum( + torch.stack(node_es_list, dim=0), dim=0 + ) # [n_nodes, ] + 
node_inter_es = self.scale_shift(node_inter_es) + + # Sum over nodes in graph + inter_e = scatter_sum( + src=node_inter_es, index=data["batch"], dim=-1, dim_size=data["num_graphs"] + ) # [n_graphs,] + + # Add E_0 and (scaled) interaction energy + total_e = e0 + inter_e + + output = { + "energy": total_e, + "forces": compute_forces( + energy=inter_e, positions=data.positions, training=training + ), + } + + return output From 661677641088c1de74215b24d61822d019a70e57 Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Sun, 2 Apr 2023 15:41:30 +0100 Subject: [PATCH 7/7] easily added ScaleShiftEnergyDipoleMACE --- mace/modules/__init__.py | 1 + mace/modules/models.py | 62 +++++++++++++++++++++ tests/test_models.py | 114 +++++++++++++++++++++++++++------------ 3 files changed, 143 insertions(+), 34 deletions(-) diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index f7409f2f..73cb2d1f 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -36,6 +36,7 @@ EnergyDipolesMACE, ScaleShiftBOTNet, ScaleShiftMACE, + ScaleShiftEnergyDipoleMACE, ) from .radial import BesselBasis, PolynomialCutoff from .symmetric_contraction import SymmetricContraction diff --git a/mace/modules/models.py b/mace/modules/models.py index 23a544d8..2d047f02 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -485,6 +485,68 @@ def forward( } +@compile_mode("script") +class ScaleShiftEnergyDipoleMACE(EnergyDipolesMACE, ScaleShiftEnergyModelMixin): + """MACE model with Scaled and Shifted Energy, & Dipoles + + Same as EnergyDipolesMACE model, but allows for constant shift and rescaling of + energy + """ + + def __init__(self, **kwargs): + """MACE energy & dipole model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + 
interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) + + **kwargs + + Notes + ----- + Requires at least l=1 features to be used for representing dipoles. + `atomic_energies` is used: it is the per-element e0 added to the node energies + + """ + super().__init__(**kwargs) + + class BOTNet(torch.nn.Module): def __init__( self, diff --git a/tests/test_models.py b/tests/test_models.py index d0177e16..d42daa71 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch import torch.nn.functional from e3nn import o3 @@ -6,6 +7,7 @@ from scipy.spatial.transform import Rotation as R from mace import data, modules, tools +from mace.modules.models import ScaleShiftEnergyDipoleMACE from mace.tools import torch_geometric torch.set_default_dtype(torch.float64) @@ -50,12 +52,12 @@ atomic_energies = np.array([1.0, 3.0], dtype=float) -def test_mace(): - # Create MACE model - model_config = dict( +@pytest.fixture +def dipole_model_config() -> dict: + return dict( r_max=5, num_bessel=8, num_polynomial_cutoff=5, max_ell=2, interaction_cls=modules.interaction_classes[
"RealAgnosticResidualInteractionBlock" @@ -63,19 +65,20 @@ def test_mace(): interaction_cls_first=modules.interaction_classes[ "RealAgnosticResidualInteractionBlock" ], - num_interactions=5, + num_interactions=2, num_elements=2, - hidden_irreps=o3.Irreps("32x0e + 32x1o"), + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), MLP_irreps=o3.Irreps("16x0e"), gate=torch.nn.functional.silu, atomic_energies=atomic_energies, - avg_num_neighbors=8, + avg_num_neighbors=3, atomic_numbers=table.zs, correlation=3, ) - model = modules.MACE(**model_config) - model_compiled = jit.compile(model) + +@pytest.fixture +def data_batch_1() -> dict: atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) atomic_data2 = data.AtomicData.from_config( config_rotated, z_table=table, cutoff=3.0 @@ -88,8 +91,37 @@ def test_mace(): drop_last=False, ) batch = next(iter(data_loader)) - output1 = model(batch.to_dict(), training=True) - output2 = model_compiled(batch.to_dict(), training=True) + return batch.to_dict() + + +def test_mace(data_batch_1): + # Create MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=5, + num_elements=2, + hidden_irreps=o3.Irreps("32x0e + 32x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=8, + atomic_numbers=table.zs, + correlation=3, + ) + model = modules.MACE(**model_config) + model_compiled = jit.compile(model) + + output1 = model(data_batch_1, training=True) + output2 = model_compiled(data_batch_1, training=True) assert torch.allclose(output1["energy"][0], output2["energy"][0]) assert torch.allclose(output2["energy"][0], output2["energy"][1]) @@ -144,30 +176,9 @@ def test_dipole_mace(): ) -def 
test_energy_dipole_mace(): +def test_energy_dipole_mace(dipole_model_config): # create dipole MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=5, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - ) - model = modules.EnergyDipolesMACE(**model_config) + model = modules.EnergyDipolesMACE(**dipole_model_config) atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) atomic_data2 = data.AtomicData.from_config( @@ -194,3 +205,38 @@ def test_energy_dipole_mace(): np.array(rot @ output["dipole"][0].detach().numpy()), output["dipole"][1].detach().numpy(), ) + + +def test_scale_shift_dipole_mace(dipole_model_config, data_batch_1): + dipole_model_config.update({"atomic_inter_scale": 2.0, "atomic_inter_shift": 1.0}) + + # create dipole MACE model + model = modules.ScaleShiftEnergyDipoleMACE(**dipole_model_config) + + output = model( + data_batch_1, + training=True, + ) + + for key in ["energy", "node_energy", "forces", "dipole", "atomic_dipoles"]: + assert key in output + + +def test_scaled_and_shifted(dipole_model_config, data_batch_1): + dipole_model_config.update({"atomic_inter_scale": 2.0, "atomic_inter_shift": 1.0}) + + # create dipole MACE model + model = modules.ScaleShiftEnergyDipoleMACE(**dipole_model_config) + + output_scale_shift = model( + data_batch_1, + training=True, + ) + + # change the shift now + model.scale_shift.shift += 0.5 + output_different_scale = modules.EnergyDipolesMACE.forward(model, data_batch_1) + assert torch.allclose( + output_different_scale["node_energy"] - 0.5, + 
output_scale_shift["node_energy"], + )