diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index f7409f2f..73cb2d1f 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -36,6 +36,7 @@ EnergyDipolesMACE, ScaleShiftBOTNet, ScaleShiftMACE, + ScaleShiftEnergyDipoleMACE, ) from .radial import BesselBasis, PolynomialCutoff from .symmetric_contraction import SymmetricContraction diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 382831b9..fa26c0f6 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -66,12 +66,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [... @compile_mode("script") class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False): + def __init__(self, irreps_in: o3.Irreps): super().__init__() - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@compile_mode("script") +class LinearDipoleOnlyReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps): + super().__init__() + self.irreps_out = o3.Irreps("1x1o") self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] @@ -100,24 +108,38 @@ def __init__( [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] ) irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) - self.equivariant_nonlin = nn.Gate( + self.non_linearity = nn.Gate( irreps_scalars=irreps_scalars, act_scalars=[gate for _, ir in irreps_scalars], irreps_gates=irreps_gates, act_gates=[gate] * len(irreps_gates), irreps_gated=irreps_gated, ) - self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() + self.irreps_nonlin = self.non_linearity.irreps_in.simplify() self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin) self.linear_2 = o3.Linear( irreps_in=self.hidden_irreps, irreps_out=self.irreps_out ) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - x = self.equivariant_nonlin(self.linear_1(x)) + x = self.non_linearity(self.linear_1(x)) return self.linear_2(x) # [n_nodes, 1] +@compile_mode("script") +class NonLinearDipoleOnlyReadoutBlock(NonLinearDipoleReadoutBlock): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Callable, + ): + # fixme: use more reasonable inheritance + super().__init__( + irreps_in=irreps_in, MLP_irreps=MLP_irreps, gate=gate, dipole_only=True + ) + + @compile_mode("script") class AtomicEnergiesBlock(torch.nn.Module): atomic_energies: torch.Tensor @@ -258,7 +280,8 @@ def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int): def forward( self, - sender_or_receiver_node_attrs: torch.Tensor, # assumes that the node attributes are one-hot encoded + sender_or_receiver_node_attrs: torch.Tensor, + # assumes that the node attributes are one-hot encoded edge_feats: torch.Tensor, ): return torch.einsum( diff --git a/mace/modules/core_models.py b/mace/modules/core_models.py new file mode 100644 index 00000000..61b9a42a --- /dev/null +++ b/mace/modules/core_models.py @@ -0,0 +1,378 @@ +"""Core model classes, and mixins for various quantities of interest + +Notes +----- +`MaceCoreModel` is the backbone of MACE models, stores all needed blocks and allows +for customisable readout shapes. + +Mixin classes are defined for calculating quantities: +- EnergyModelMixin: inherit from if you want to calculate energies +- ScaleShiftEnergyModelMixin: energies with scaling and shifting +- DipoleModelMixin: dipole learning + +""" +from typing import Type, List, Optional, Callable, Dict, Tuple + +import numpy as np +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.modules import ( + LinearReadoutBlock, + NonLinearReadoutBlock, + InteractionBlock, + LinearNodeEmbeddingBlock, + RadialEmbeddingBlock, + EquivariantProductBasisBlock, + AtomicEnergiesBlock, + ScaleShiftBlock, +) +from mace.modules.utils import get_edge_vectors_and_lengths, compute_fixed_charge_dipole +from mace.tools.scatter import scatter_sum + + +@compile_mode("script") +class MaceCoreModel(torch.nn.Module): + """Core model for all MACE models + + Includes the following + - graph parameters: r_max, elements, etc. + - embeddings (node, edge) + - readout blocks (subclasses can change settings of these) + + """ + + _LINEAR_READOUT_CLASS = LinearReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock + + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + radial_MLP: Optional[List[int]] = None, + **kwargs, + ): + """Core MACE model + + Parameters + ---------- + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + """ + super().__init__(**kwargs) + + # Main Buffers + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + + # Interactions and readout + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = self._last_layer_irreps(hidden_irreps) + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + self._NONLINEAR_READOUT_CLASS(hidden_irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(self._LINEAR_READOUT_CLASS(hidden_irreps)) + + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """Irreps to use in the last layer - used for initialisation of subclasses + + core model: drops highest l irreps, unless it's scalar only + + """ + if len(hidden_irreps) == 1: + return hidden_irreps + return o3.Irreps(str(hidden_irreps[:-1])) + + def _calculate_layer_interactions( + self, + data: Dict[str, torch.Tensor], + ) -> torch.Tensor: + """Calculate the layer interactions - used within forward pass + + Parameters + ---------- + data + + Returns + ------- + layer_outputs + shape: [n_nodes, len(ouptuts)] + + """ + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + # Interactions + layer_outputs = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + + layer_outputs.append(readout(node_feats).squeeze(-1)) # [n_nodes, ] + + return torch.sum(torch.stack(layer_outputs, dim=0), dim=0) + + +@compile_mode("script") +class EnergyModelMixin(torch.nn.Module): + """Mixin class for energy models + + Supplies: + - e0: atomic energy block + + """ + + def __init__(self, atomic_energies: np.ndarray, **kwargs): + super().__init__(**kwargs) + + # Energy calculation specific bits + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + def _calc_energy( + self, + data: Dict[str, torch.Tensor], + layer_output_energies: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_graphs = data["ptr"].numel() - 1 + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + + # Sum over energy contributions + node_energy = node_e0 + layer_output_energies + total_energy = scatter_sum( + src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + return total_energy, node_energy + + def _calculate_e0(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: + """e0 calculation, element-wise energy shift + + Notes + ----- + This is really a lookup from the e0 list we have for each node's element. + + """ + return self.atomic_energies_fn(data["node_attrs"]) + + +@compile_mode("script") +class ScaleShiftEnergyModelMixin(EnergyModelMixin): + """Scaled and shifted energy model""" + + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + """ + + Parameters + ---------- + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) + + **kwargs + """ + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def _calc_energy( + self, + data: Dict[str, torch.Tensor], + layer_output_energies: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return super()._calc_energy(data, self.scale_shift(layer_output_energies)) + + +@compile_mode("script") +class DipoleModelMixin(torch.nn.Module): + """Mixin class for dipole models + + Supplies: + - function to calculate total dipole, including fixed charge baseline + + """ + + @staticmethod + def _calc_total_dipole( + data: Dict[str, torch.Tensor], + atomic_dipoles: torch.Tensor, + ) -> torch.Tensor: + """Calculates total dipoles - adding fixed charge baseline + + Parameters + ---------- + data + atomic_dipoles + corresponding output of the interaction layers + + Returns + ------- + total_dipole + shape: [n_graphs,3] + + """ + num_graphs = data["ptr"].numel() - 1 + + # Sum over dipole contributions + baseline_dipole = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"].unsqueeze(-1), + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + total_dipole += baseline_dipole + + return total_dipole # [n_graphs,3] diff --git a/mace/modules/models.py b/mace/modules/models.py index 736cf1cb..2d047f02 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -1,8 +1,14 @@ -########################################################################################### +######################################################################################## # Implementation of MACE models and other models based E(3)-Equivariant MPNNs # Authors: Ilyes Batatia, Gregor Simm # This program is distributed under the MIT License (see MIT.md) -########################################################################################### +######################################################################################## +"""Models + +These are the high-level models to be exposed. Internal workings are kept in the class +hierarchy. + +""" from typing import Any, Callable, Dict, List, Optional, Type @@ -13,21 +19,26 @@ from mace.data import AtomicData from mace.tools.scatter import scatter_sum - from .blocks import ( AtomicEnergiesBlock, - EquivariantProductBasisBlock, InteractionBlock, LinearDipoleReadoutBlock, + LinearDipoleOnlyReadoutBlock, LinearNodeEmbeddingBlock, LinearReadoutBlock, NonLinearDipoleReadoutBlock, + NonLinearDipoleOnlyReadoutBlock, NonLinearReadoutBlock, RadialEmbeddingBlock, ScaleShiftBlock, ) +from .core_models import ( + MaceCoreModel, + EnergyModelMixin, + ScaleShiftEnergyModelMixin, + DipoleModelMixin, +) from .utils import ( - compute_fixed_charge_dipole, compute_forces, get_edge_vectors_and_lengths, get_outputs, @@ -36,120 +47,59 @@ @compile_mode("script") -class MACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - radial_MLP: Optional[List[int]] = None, - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readout - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearReadoutBlock(hidden_irreps)) +class MACE(MaceCoreModel, EnergyModelMixin): + """MACE model - energy only""" + + _LINEAR_READOUT_CLASS = LinearReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearReadoutBlock + + def __init__(self, **kwargs): + """MACE energy model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + + **kwargs + """ + super().__init__(**kwargs) - for i in range(num_interactions - 1): - if i == num_interactions - 2: - hidden_irreps_out = str( - hidden_irreps[0] - ) # Select only scalars for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) - ) - else: - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """energy: Select only scalars for last layer""" + return o3.Irreps(str(hidden_irreps[0])) def forward( self, @@ -160,6 +110,32 @@ def forward( compute_stress: bool = False, compute_displacement: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: + """Forward pass - evaluation of model + + Parameters + ---------- + data + input graphs, as dictionary + training + If we are in training mode. Computational graphs are retained, for reuse. + compute_force + Compute forces on all nodes + compute_virials, compute_stress + Compute stress and virial, any of these two triggers calculation of both + compute_displacement + Compute symmetric displacements + + Returns + ------- + results + calculated results + - `energy`: total energy on each graph + - `node_energy`: energy of each node, i.e. local energy on each atom + - `forces`: negative gradient of total energy on each atom (node) + - `virials` & `stress`: same thing expressed differently + - `displacement`: symmetric displacement of cell used for stress + + """ # Setup data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) @@ -183,54 +159,12 @@ def forward( batch=data["batch"], ) - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], + # Calculate energies + total_energy, node_energy = self._calc_energy( + data, self._calculate_layer_interactions(data) ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) - # Interactions - energies = [e0] - node_energies_list = [node_e0] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - energies.append(energy) - node_energies_list.append(node_energies) - - # Sum over energy contributions - contributions = torch.stack(energies, dim=-1) - print("contributions", contributions.shape) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - node_energy_contributions = torch.stack(node_energies_list, dim=-1) - node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] - # Outputs + # Calculate derivatives if needed forces, virials, stress = get_outputs( energy=total_energy, positions=data["positions"], @@ -245,7 +179,6 @@ def forward( return { "energy": total_energy, "node_energy": node_energy, - "contributions": contributions, "forces": forces, "virials": virials, "stress": stress, @@ -254,102 +187,288 @@ def forward( @compile_mode("script") -class ScaleShiftMACE(MACE): - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): +class ScaleShiftMACE(MACE, ScaleShiftEnergyModelMixin): + """Scaled and Shifted MACE energy model + + Same as MACE model, but allows for constant shift and rescaling of energy + """ + + def __init__(self, **kwargs): + """Scaled & Shifted MACE energy model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) + + **kwargs + """ super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) + + +@compile_mode("script") +class AtomicDipolesMACE(MaceCoreModel, DipoleModelMixin): + """MACE model for dipoles only""" + + _LINEAR_READOUT_CLASS = LinearDipoleOnlyReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearDipoleOnlyReadoutBlock + + def __init__(self, hidden_irreps: o3.Irreps, atomic_energies=None, **kwargs): + """MACE dipole only model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + **kwargs + + Notes + ----- + Requires at least l=1 features to be used for representing dipoles. + `atomic_energies` parameter added for compatibility, but it's ignored + + """ + assert ( + hidden_irreps.lmax >= 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + super().__init__(hidden_irreps=hidden_irreps, **kwargs) + + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """dipole: Select only l=1 vectors for last layer""" + return o3.Irreps(str(hidden_irreps[1])) def forward( self, data: Dict[str, torch.Tensor], training: bool = False, - compute_force: bool = True, + compute_force: bool = False, compute_virials: bool = False, compute_stress: bool = False, compute_displacement: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: + """Forward pass - evaluation of model + + Parameters + ---------- + data + input graphs, as dictionary + training + If we are in training mode. Computational graphs are retained, for reuse. + compute_force, compute_virials, compute_stress, compute_displacement + compatibility only, RAISES ERROR IF True + + Returns + ------- + results + calculated results + - `dipole`: dipole of the graph, includes atomic and fixed charge dipole + components as well + - `atomic_dipoles`: dipoles of each atom + + """ + + # dipoles and virials / stress not supported simultaneously + error_msg = "AtomicDipolesMACE does not support energy & its derivatives" + assert not compute_force, error_msg + assert not compute_virials, error_msg + assert not compute_stress, error_msg + assert not compute_displacement, error_msg + # Setup - data["positions"].requires_grad_(True) - data["node_attrs"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) + data["positions"].requires_grad = True - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + # Evaluate layer outputs + atomic_dipoles = self._calculate_layer_interactions(data) - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) + return { + "dipole": self._calc_total_dipole(data, atomic_dipoles), + "atomic_dipoles": atomic_dipoles, + } - # Interactions - node_es_list = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] - ) - node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } - # Sum over interactions - node_inter_es = torch.sum( - torch.stack(node_es_list, dim=0), dim=0 - ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) +@compile_mode("script") +class EnergyDipolesMACE(MACE, DipoleModelMixin): + """MACE model for Energy & Dipoles""" + + _LINEAR_READOUT_CLASS = LinearDipoleReadoutBlock + _NONLINEAR_READOUT_CLASS = NonLinearDipoleReadoutBlock + + def __init__(self, hidden_irreps: o3.Irreps, **kwargs): + """MACE energy & dipole model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + + **kwargs + + Notes + ----- + Requires at least l=1 features to be used for representing dipoles. + `atomic_energies` parameter added for compatibility, but it's ignored + + """ + assert ( + hidden_irreps.lmax >= 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + super().__init__(hidden_irreps=hidden_irreps, **kwargs) + + @staticmethod + def _last_layer_irreps(hidden_irreps) -> o3.Irreps: + """energy & dipole: Select scalars and l=1 vectors for last layer""" + return o3.Irreps(str(hidden_irreps[:2])) - # Sum over nodes in graph - inter_e = scatter_sum( - src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + """Forward pass - evaluation of model + + Parameters + ---------- + data + input graphs, as dictionary + training + If we are in training mode. Computational graphs are retained, for reuse. + compute_force + Compute forces on all nodes + compute_virials, compute_stress, compute_displacement + compatibility only, RAISES ERROR IF True + + Returns + ------- + results + calculated results + - `energy`: total energy on each graph + - `node_energy`: energy of each node, i.e. local energy on each atom + - `forces`: negative gradient of total energy on each atom (node) + - `dipole`: dipole of the graph, includes atomic and fixed charge dipole + components as well + - `atomic_dipoles`: dipoles of each atom + + """ + # dipoles and virials / stress not supported simultaneously + error_msg = "dipoles and virials / stress not supported simultaneously" + assert not compute_virials, error_msg + assert not compute_stress, error_msg + assert not compute_displacement, error_msg - # Add E_0 and (scaled) interaction energy - total_energy = e0 + inter_e - node_energy = node_e0 + node_inter_es + # Setup + data["positions"].requires_grad = True - forces, virials, stress = get_outputs( - energy=inter_e, + # Evaluate layer outputs & unpack + layer_outputs = self._calculate_layer_interactions(data) + interaction_energies = layer_outputs[:, 0] + atomic_dipoles = layer_outputs[:, 1:] + + # Calculate energies + total_energy, node_energy = self._calc_energy(data, interaction_energies) + + # Calculate derivatives if needed + forces, _, _ = get_outputs( + energy=total_energy, positions=data["positions"], - displacement=displacement, + displacement=None, cell=data["cell"], training=training, compute_force=compute_force, @@ -357,17 +476,75 @@ def forward( compute_stress=compute_stress, ) - output = { + return { "energy": total_energy, "node_energy": node_energy, - "interaction_energy": inter_e, "forces": forces, - "virials": virials, - "stress": stress, - "displacement": displacement, + "dipole": self._calc_total_dipole(data, atomic_dipoles), + "atomic_dipoles": atomic_dipoles, } - return output + +@compile_mode("script") +class ScaleShiftEnergyDipoleMACE(EnergyDipolesMACE, ScaleShiftEnergyModelMixin): + """MACE model with Scaled and Shifted Energy, & Dipoles + + Same as EnergyDipolesMACE model, but allows for constant shift and rescaling of + energy + """ + + def __init__(self, **kwargs): + """MACE energy & dipole model + + Parameters + ---------- + # ------- core model parameters + r_max + cutoff of radial embedding applied to individual atoms + num_bessel + number of Bessel functions to be used for radial embedding + num_polynomial_cutoff + max_ell + l_max, maximum channels for spherical harmonics + interaction_cls + class to be used for interactions blocks + interaction_cls_first + class to be used for the first layer's interaction block + num_interactions + number of interaction layers to use + num_elements + redundant parameter for the number of elements + hidden_irreps + hidden irreducible representations, basically the size of the layer features + and hence direct control on the size of the model + MLP_irreps + avg_num_neighbors + atomic_numbers + atomic numbers of elements the model supports + order & length to agree with atomic_energies + correlation + gate + non-linearity for non-linear readouts + radial_MLP + + # ------- energy model specific parameters + atomic_energies + energy shift (e0) of elements + order & length to agree with atomic_numbers + atomic_inter_scale + scale of interaction energy + atomic_inter_shift + constant shift of interaction energy (per atom) + + **kwargs + + Notes + ----- + Requires at least l=1 features to be used for representing dipoles. + `atomic_energies` parameter added for compatibility, but it's ignored + + """ + super().__init__(**kwargs) class BOTNet(torch.nn.Module): @@ -387,8 +564,9 @@ def __init__( gate: Optional[Callable], avg_num_neighbors: float, atomic_numbers: List[int], + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.r_max = r_max self.atomic_numbers = atomic_numbers # Embedding @@ -448,15 +626,15 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: data.positions.requires_grad = True # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"]) e0 = scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] ) # [n_graphs,] # Embeddings - node_feats = self.node_embedding(data.node_attrs) + node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) @@ -465,15 +643,18 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: energies = [e0] for interaction, readout in zip(self.interactions, self.readouts): node_feats = interaction( - node_attrs=data.node_attrs, + node_attrs=data["node_attrs"], node_feats=node_feats, edge_attrs=edge_attrs, edge_feats=edge_feats, - edge_index=data.edge_index, + edge_index=data["edge_index"], ) node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] energy = scatter_sum( - src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_energies, + index=data["batch"], + dim=-1, + dim_size=data["num_graphs"], ) # [n_graphs,] energies.append(energy) @@ -509,15 +690,15 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: data.positions.requires_grad = True # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"]) e0 = scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_e0, index=data["batch"], dim=-1, dim_size=data["num_graphs"] ) # [n_graphs,] # Embeddings - node_feats = self.node_embedding(data.node_attrs) + node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + positions=data.positions, edge_index=data["edge_index"], shifts=data.shifts ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) @@ -526,11 +707,11 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: node_es_list = [] for interaction, readout in zip(self.interactions, self.readouts): node_feats = interaction( - node_attrs=data.node_attrs, + node_attrs=data["node_attrs"], node_feats=node_feats, edge_attrs=edge_attrs, edge_feats=edge_feats, - edge_index=data.edge_index, + edge_index=data["edge_index"], ) node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } @@ -543,7 +724,7 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: # Sum over nodes in graph inter_e = scatter_sum( - src=node_inter_es, index=data.batch, dim=-1, dim_size=data.num_graphs + src=node_inter_es, index=data["batch"], dim=-1, dim_size=data["num_graphs"] ) # [n_graphs,] # Add E_0 and (scaled) interaction energy @@ -557,408 +738,3 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: } return output - - -class AtomicDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[ - None - ], # Just here to make it compatible with energy models, MUST be None - ): - super().__init__() - assert atomic_energies is None - self.r_max = r_max - self.atomic_numbers = atomic_numbers - - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - - # Interactions and readouts - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[1] - ) # Select only l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=inter.irreps_out, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=True - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) - ) - - def forward( - self, - data: AtomicData, - training=False, - compute_force: bool = False, - compute_virials: bool = False, - compute_stress: bool = False, - ) -> Dict[str, Any]: - assert compute_force is False - assert compute_virials is False - assert compute_stress is False - # Setup - data.positions.requires_grad = True - if not training: - for p in self.parameters(): - p.requires_grad = False - else: - for p in self.parameters(): - p.requires_grad = True - - # Embeddings - node_feats = self.node_embedding(data.node_attrs) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) - - # Interactions - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data.node_attrs, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data.edge_index, - ) - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data.node_attrs - ) - node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] - dipoles.append(node_dipoles) - - # Compute the dipoles - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data.batch.unsqueeze(-1), - dim=0, - dim_size=data.num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data.charges, - positions=data.positions, - batch=data.batch, - num_graphs=data.num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - output = { - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - - return output - - -class EnergyDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[np.ndarray], - ): - super().__init__() - self.r_max = r_max - self.atomic_numbers = atomic_numbers - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - - # Interactions and readouts - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[:2] - ) # Select scalars and l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=inter.irreps_out, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=False - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) - ) - - def forward( - self, - data: AtomicData, - training=False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - ) -> Dict[str, Any]: - # dipoles and virials / stress not supported simultaneously - assert compute_virials is False - assert compute_stress is False - # Setup - data.positions.requires_grad = True - if not training: - for p in self.parameters(): - p.requires_grad = False - else: - for p in self.parameters(): - p.requires_grad = True - - # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) - e0 = scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data.node_attrs) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding(lengths) - - # Interactions - energies = [e0] - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data.node_attrs, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data.edge_index, - ) - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data.node_attrs - ) - node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] - # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - node_energies = node_out[:, 0] - energy = scatter_sum( - src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] - energies.append(energy) - # node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] - node_dipoles = node_out[:, 1:] - dipoles.append(node_dipoles) - - # Compute the energies and dipoles - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data.batch.unsqueeze(-1), - dim=0, - dim_size=data.num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data.charges, - positions=data.positions, - batch=data.batch, - num_graphs=data.num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - forces, _, _ = get_outputs( - energy=total_energy, - positions=data.positions, - displacement=None, - cell=data.cell, - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - ) - - output = { - "energy": total_energy, - "contributions": contributions, - "forces": forces, - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - - return output diff --git a/tests/test_models.py b/tests/test_models.py index f39c437f..d42daa71 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch import torch.nn.functional from e3nn import o3 @@ -6,6 +7,7 @@ from scipy.spatial.transform import Rotation as R from mace import data, modules, tools +from mace.modules.models import ScaleShiftEnergyDipoleMACE from mace.tools import torch_geometric torch.set_default_dtype(torch.float64) @@ -50,12 +52,12 @@ atomic_energies = np.array([1.0, 3.0], dtype=float) -def test_mace(): - # Create MACE model - model_config = dict( +@pytest.fixture +def dipole_model_config() -> dict: + return dict( r_max=5, num_bessel=8, - num_polynomial_cutoff=6, + num_polynomial_cutoff=5, max_ell=2, interaction_cls=modules.interaction_classes[ "RealAgnosticResidualInteractionBlock" @@ -63,19 +65,20 @@ def test_mace(): interaction_cls_first=modules.interaction_classes[ "RealAgnosticResidualInteractionBlock" ], - num_interactions=5, + num_interactions=2, num_elements=2, - hidden_irreps=o3.Irreps("32x0e + 32x1o"), + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), MLP_irreps=o3.Irreps("16x0e"), gate=torch.nn.functional.silu, atomic_energies=atomic_energies, - avg_num_neighbors=8, + avg_num_neighbors=3, atomic_numbers=table.zs, correlation=3, ) - model = modules.MACE(**model_config) - model_compiled = jit.compile(model) + +@pytest.fixture +def data_batch_1() -> dict: atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) atomic_data2 = data.AtomicData.from_config( config_rotated, z_table=table, cutoff=3.0 @@ -88,8 +91,37 @@ def test_mace(): drop_last=False, ) batch = next(iter(data_loader)) - output1 = model(batch.to_dict(), training=True) - output2 = model_compiled(batch.to_dict(), training=True) + return batch.to_dict() + + +def test_mace(data_batch_1): + # Create MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=5, + num_elements=2, + hidden_irreps=o3.Irreps("32x0e + 32x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=8, + atomic_numbers=table.zs, + correlation=3, + ) + model = modules.MACE(**model_config) + model_compiled = jit.compile(model) + + output1 = model(data_batch_1, training=True) + output2 = model_compiled(data_batch_1, training=True) assert torch.allclose(output1["energy"][0], output2["energy"][0]) assert torch.allclose(output2["energy"][0], output2["energy"][1]) @@ -144,30 +176,9 @@ def test_dipole_mace(): ) -def test_energy_dipole_mace(): +def test_energy_dipole_mace(dipole_model_config): # create dipole MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=5, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - ) - model = modules.EnergyDipolesMACE(**model_config) + model = modules.EnergyDipolesMACE(**dipole_model_config) atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) atomic_data2 = data.AtomicData.from_config( @@ -182,7 +193,7 @@ def test_energy_dipole_mace(): ) batch = next(iter(data_loader)) output = model( - batch, + batch.to_dict(), training=True, ) # sanity check of dipoles being the right shape @@ -194,3 +205,38 @@ def test_energy_dipole_mace(): np.array(rot @ output["dipole"][0].detach().numpy()), output["dipole"][1].detach().numpy(), ) + + +def test_scale_shift_dipole_mace(dipole_model_config, data_batch_1): + dipole_model_config.update({"atomic_inter_scale": 2.0, "atomic_inter_shift": 1.0}) + + # create dipole MACE model + model = modules.ScaleShiftEnergyDipoleMACE(**dipole_model_config) + + output = model( + data_batch_1, + training=True, + ) + + for key in ["energy", "node_energy", "forces", "dipole", "atomic_dipoles"]: + assert key in output + + +def test_scaled_and_shifted(dipole_model_config, data_batch_1): + dipole_model_config.update({"atomic_inter_scale": 2.0, "atomic_inter_shift": 1.0}) + + # create dipole MACE model + model = modules.ScaleShiftEnergyDipoleMACE(**dipole_model_config) + + output_scale_shift = model( + data_batch_1, + training=True, + ) + + # change the shift now + model.scale_shift.shift += 0.5 + output_different_scale = modules.EnergyDipolesMACE.forward(model, data_batch_1) + assert torch.allclose( + output_different_scale["node_energy"] - 0.5, + output_scale_shift["node_energy"], + )