diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d78624bb..6f8c2daa 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -55,5 +55,6 @@ repos:
           '--disable=cell-var-from-loop',
           '--disable=duplicate-code',
           '--disable=use-dict-literal',
+          '--max-module-lines=1500',
         ]
         exclude: *exclude_files
\ No newline at end of file
diff --git a/mace/__init__.py b/mace/__init__.py
index 9226fe7e..e9c9ef48 100644
--- a/mace/__init__.py
+++ b/mace/__init__.py
@@ -1 +1 @@
-from .__version__ import __version__
+from .__version__ import __version__
\ No newline at end of file
diff --git a/mace/cli/convert_cueq_e3nn.py b/mace/cli/convert_cueq_e3nn.py
new file mode 100644
index 00000000..57b2808e
--- /dev/null
+++ b/mace/cli/convert_cueq_e3nn.py
@@ -0,0 +1,206 @@
+import argparse
+import logging
+import os
+from typing import Dict, List, Tuple
+
+import torch
+
+from mace.tools.scripts_utils import extract_config_mace_model
+
+
+def get_transfer_keys() -> List[str]:
+    """Get list of keys that need to be transferred"""
+    return [
+        "node_embedding.linear.weight",
+        "radial_embedding.bessel_fn.bessel_weights",
+        "atomic_energies_fn.atomic_energies",
+        "readouts.0.linear.weight",
+        "scale_shift.scale",
+        "scale_shift.shift",
+        *[f"readouts.1.linear_{i}.weight" for i in range(1, 3)],
+    ] + [
+        s
+        for j in range(2)
+        for s in [
+            f"interactions.{j}.linear_up.weight",
+            *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)],
+            f"interactions.{j}.linear.weight",
+            f"interactions.{j}.skip_tp.weight",
+            f"products.{j}.linear.weight",
+        ]
+    ]
+
+
+def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]:
+    """Determine kmax pairs based on max_L and correlation"""
+    if correlation == 2:
+        raise NotImplementedError("Correlation 2 not supported yet")
+    if correlation == 3:
+        return [(0, max_L), (1, 0)]
+    raise NotImplementedError(f"Correlation {correlation} not supported")
+
+
+def transfer_symmetric_contractions(
+    source_dict: Dict[str, torch.Tensor],
+    target_dict: Dict[str, torch.Tensor],
+    max_L: int,
+    correlation: int,
+):
+    """Transfer symmetric contraction weights from CuEq to E3nn format"""
+    kmax_pairs = get_kmax_pairs(max_L, correlation)
+    logging.info(
+        f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}"
+    )
+
+    for i, kmax in kmax_pairs:
+        # Get the combined weight tensor from source
+        wm = source_dict[f"products.{i}.symmetric_contractions.weight"]
+
+        # Get split sizes based on target dimensions
+        splits = []
+        for k in range(kmax + 1):
+            for suffix in ["_max", ".0", ".1"]:
+                key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}"
+                target_shape = target_dict[key].shape
+                splits.append(target_shape[1])
+
+        # Split the weights using the calculated sizes
+        weights_split = torch.split(wm, splits, dim=1)
+
+        # Assign back to target dictionary
+        idx = 0
+        for k in range(kmax + 1):
+            target_dict[
+                f"products.{i}.symmetric_contractions.contractions.{k}.weights_max"
+            ] = weights_split[idx]
+            target_dict[
+                f"products.{i}.symmetric_contractions.contractions.{k}.weights.0"
+            ] = weights_split[idx + 1]
+            target_dict[
+                f"products.{i}.symmetric_contractions.contractions.{k}.weights.1"
+            ] = weights_split[idx + 2]
+            idx += 3
+
+
+def transfer_weights(
+    source_model: torch.nn.Module,
+    target_model: torch.nn.Module,
+    max_L: int,
+    correlation: int,
+):
+    """Transfer weights from CuEq to E3nn format"""
+    # Get state dicts
+    source_dict = source_model.state_dict()
+    target_dict = target_model.state_dict()
+
+    # Transfer main weights
+    transfer_keys = get_transfer_keys()
+    logging.info("Transferring main weights...")
+    for key in transfer_keys:
+        if key in source_dict:  # Check if key exists
+            target_dict[key] = source_dict[key]
+        else:
+            logging.warning(f"Key {key} not found in source model")
+
+    # Transfer symmetric contractions
+    logging.info("Transferring symmetric contractions...")
+    transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)
+
+    # Transfer remaining matching keys
+    transferred_keys = set(transfer_keys)
+    remaining_keys = (
+        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
+    )
+    remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
+
+    if remaining_keys:
+        logging.info(
+            f"Found {len(remaining_keys)} additional matching keys to transfer"
+        )
+        for key in remaining_keys:
+            if source_dict[key].shape == target_dict[key].shape:
+                logging.debug(f"Transferring additional key: {key}")
+                target_dict[key] = source_dict[key]
+            else:
+                logging.warning(
+                    f"Shape mismatch for key {key}: "
+                    f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
+                )
+
+    # Transfer avg_num_neighbors
+    for i in range(2):
+        target_model.interactions[i].avg_num_neighbors = source_model.interactions[
+            i
+        ].avg_num_neighbors
+
+    # Load state dict into target model
+    target_model.load_state_dict(target_dict)
+
+
+def run(input_model, output_model="_e3nn.model", device="cuda", return_model=True):
+    # Setup logging
+    logging.basicConfig(level=logging.INFO)
+
+    # Load CuEq model
+    logging.info(f"Loading CuEq model from {input_model}")
+    if isinstance(input_model, str):
+        source_model = torch.load(input_model, map_location=device)
+    else:
+        source_model = input_model
+
+    # Extract configuration
+    logging.info("Extracting model configuration")
+    config = extract_config_mace_model(source_model)
+
+    # Get max_L and correlation from config
+    max_L = config["hidden_irreps"].lmax
+    correlation = config["correlation"]
+    logging.info(f"Extracted max_L={max_L}, correlation={correlation}")
+
+    # Remove CuEq config
+    config.pop("cueq_config", None)
+
+    # Create new model without CuEq config
+    logging.info("Creating new model without CuEq settings")
+    target_model = source_model.__class__(**config)
+
+    # Transfer weights with proper remapping
+    logging.info("Transferring weights with remapping...")
+    transfer_weights(source_model, target_model, max_L, correlation)
+
+    if return_model:
+        return target_model
+
+    # Save model
+    if isinstance(input_model, str):
+        base = os.path.splitext(input_model)[0]
+        output_model = f"{base}.{output_model}"
+    logging.info(f"Saving E3nn model to {output_model}")
+    torch.save(target_model, output_model)
+    return None
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_model", help="Path to input CuEq model")
+    parser.add_argument(
+        "--output_model", help="Path to output E3nn model", default="e3nn_model.pt"
+    )
+    parser.add_argument("--device", default="cpu", help="Device to use")
+    parser.add_argument(
+        "--return_model",
+        action="store_false",
+        help="Pass this flag to save the converted model to file instead of returning it",
+    )
+    args = parser.parse_args()
+
+    run(
+        input_model=args.input_model,
+        output_model=args.output_model,
+        device=args.device,
+        return_model=args.return_model,
+    )
+
+
+if __name__ == "__main__":
+    main()
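A minimal usage sketch of the converter above (hedged: the model path "model_cueq.model" is illustrative; `run` and its `return_model` flag are the ones defined in this file, where `return_model=True` skips saving):

    import torch
    from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn

    # Convert in memory and save manually; pass return_model=False to let run()
    # derive an output path from the input path and save the model itself.
    model_e3nn = run_cueq_to_e3nn("model_cueq.model", device="cpu", return_model=True)
    torch.save(model_e3nn, "model_e3nn.model")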
diff --git a/mace/cli/convert_e3nn_cueq.py b/mace/cli/convert_e3nn_cueq.py
new file mode 100644
index 00000000..715d7bbf
--- /dev/null
+++ b/mace/cli/convert_e3nn_cueq.py
@@ -0,0 +1,201 @@
+import argparse
+import logging
+import os
+from typing import Dict, List, Tuple
+
+import torch
+
+from mace.modules.wrapper_ops import CuEquivarianceConfig
+from mace.tools.scripts_utils import extract_config_mace_model
+
+
+def get_transfer_keys() -> List[str]:
+    """Get list of keys that need to be transferred"""
+    return [
+        "node_embedding.linear.weight",
+        "radial_embedding.bessel_fn.bessel_weights",
+        "atomic_energies_fn.atomic_energies",
+        "readouts.0.linear.weight",
+        "scale_shift.scale",
+        "scale_shift.shift",
+        *[f"readouts.1.linear_{i}.weight" for i in range(1, 3)],
+    ] + [
+        s
+        for j in range(2)
+        for s in [
+            f"interactions.{j}.linear_up.weight",
+            *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)],
+            f"interactions.{j}.linear.weight",
+            f"interactions.{j}.skip_tp.weight",
+            f"products.{j}.linear.weight",
+        ]
+    ]
+
+
+def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]:
+    """Determine kmax pairs based on max_L and correlation"""
+    if correlation == 2:
+        raise NotImplementedError("Correlation 2 not supported yet")
+    if correlation == 3:
+        return [(0, max_L), (1, 0)]
+    raise NotImplementedError(f"Correlation {correlation} not supported")
+
+
+def transfer_symmetric_contractions(
+    source_dict: Dict[str, torch.Tensor],
+    target_dict: Dict[str, torch.Tensor],
+    max_L: int,
+    correlation: int,
+):
+    """Transfer symmetric contraction weights"""
+    kmax_pairs = get_kmax_pairs(max_L, correlation)
+    logging.info(
+        f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}"
+    )
+
+    for i, kmax in kmax_pairs:
+        wm = torch.concatenate(
+            [
+                source_dict[
+                    f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}"
+                ]
+                for k in range(kmax + 1)
+                for j in ["_max", ".0", ".1"]
+            ],
+            dim=1,
+        )  # .float()
+        target_dict[f"products.{i}.symmetric_contractions.weight"] = wm
+
+
+def transfer_weights(
+    source_model: torch.nn.Module,
+    target_model: torch.nn.Module,
+    max_L: int,
+    correlation: int,
+):
+    """Transfer weights with proper remapping"""
+    # Get source state dict
+    source_dict = source_model.state_dict()
+    target_dict = target_model.state_dict()
+
+    # Transfer main weights
+    transfer_keys = get_transfer_keys()
+    logging.info("Transferring main weights...")
+    for key in transfer_keys:
+        if key in source_dict:  # Check if key exists
+            target_dict[key] = source_dict[key]
+        else:
+            logging.warning(f"Key {key} not found in source model")
+
+    # Transfer symmetric contractions
+    logging.info("Transferring symmetric contractions...")
+    transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)
+
+    transferred_keys = set(transfer_keys)
+    remaining_keys = (
+        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
+    )
+    remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
+
+    if remaining_keys:
+        logging.info(
+            f"Found {len(remaining_keys)} additional matching keys to transfer"
+        )
+        for key in remaining_keys:
+            if source_dict[key].shape == target_dict[key].shape:
+                logging.debug(f"Transferring additional key: {key}")
+                target_dict[key] = source_dict[key]
+            else:
+                logging.warning(
+                    f"Shape mismatch for key {key}: "
+                    f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
+                )
+    # Transfer avg_num_neighbors
+    for i in range(2):
+        target_model.interactions[i].avg_num_neighbors = source_model.interactions[
+            i
+        ].avg_num_neighbors
+
+    # Load state dict into target model
+    target_model.load_state_dict(target_dict)
+
+
+def run(
+    input_model,
+    output_model="_cueq.model",
+    device="cuda",
+    return_model=True,
+):
+    # Setup logging
+    logging.basicConfig(level=logging.INFO)
+
+    # Load original model
+    logging.info(f"Loading model from {input_model}")
+    # check if input_model is a path or a model
+    if isinstance(input_model, str):
+        source_model = torch.load(input_model, map_location=device)
+    else:
+        source_model = input_model
+
+    # Extract configuration
+    logging.info("Extracting model configuration")
+    config = extract_config_mace_model(source_model)
+
+    # Get max_L and correlation from config
+    max_L = config["hidden_irreps"].lmax
+    correlation = config["correlation"]
+    logging.info(f"Extracted max_L={max_L}, correlation={correlation}")
+
+    # Add cuequivariance config
+    config["cueq_config"] = CuEquivarianceConfig(
+        enabled=True,
+        layout="mul_ir",
+        group="O3_e3nn",
+        optimize_all=True,
+    )
+
+    # Create new model with cuequivariance config
+    logging.info("Creating new model with cuequivariance settings")
+    target_model = source_model.__class__(**config)
+
+    # Transfer weights with proper remapping
+    logging.info("Transferring weights with remapping...")
+    transfer_weights(source_model, target_model, max_L, correlation)
+
+    if return_model:
+        return target_model
+
+    if isinstance(input_model, str):
+        base = os.path.splitext(input_model)[0]
+        output_model = f"{base}.{output_model}"
+    logging.info(f"Saving CuEq model to {output_model}")
+    torch.save(target_model, output_model)
+    return None
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_model", help="Path to input MACE model")
+    parser.add_argument(
+        "--output_model",
+        help="Path to output cuequivariance model",
+        default="cueq_model.pt",
+    )
+    parser.add_argument("--device", default="cpu", help="Device to use")
+    parser.add_argument(
+        "--return_model",
+        action="store_false",
+        help="Pass this flag to save the converted model to file instead of returning it",
+    )
+    args = parser.parse_args()
+
+    run(
+        input_model=args.input_model,
+        output_model=args.output_model,
+        device=args.device,
+        return_model=args.return_model,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 3813b055..4196e083 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -54,6 +54,8 @@ from mace.tools.slurm_distributed import DistributedEnvironment
 from mace.tools.tables_utils import create_error_table
 from mace.tools.utils import AtomicNumberTable
+from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
+from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
 
 
 def main() -> None:
@@ -600,7 +602,10 @@ def run(args: argparse.Namespace) -> None:
 
     if args.wandb:
         setup_wandb(args)
-
+    if args.enable_cueq:
+        logging.info("Converting model to CUEQ for accelerated training")
+        assert args.model in ["MACE", "ScaleShiftMACE"], "Model must be MACE or ScaleShiftMACE"
+        model = run_e3nn_to_cueq(model)
     if args.distributed:
         distributed_model = DDP(model, device_ids=[local_rank])
     else:
@@ -752,6 +757,8 @@
     if rank == 0:
         # Save entire model
+        if args.enable_cueq:
+            model = run_cueq_to_e3nn(model)
         if swa_eval:
            model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
         else:
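The training-time round trip added above, as a standalone sketch (hedged: `model` stands for an already-constructed MACE or ScaleShiftMACE instance, and the training loop is elided):

    from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
    from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq

    model = run_e3nn_to_cueq(model)  # CuEq representation for accelerated training
    # ... optimizer steps ...
    model = run_cueq_to_e3nn(model)  # back to pure e3nn before saving the checkpoint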
diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py
index 0db3b02e..7bc3561f 100644
--- a/mace/modules/blocks.py
+++ b/mace/modules/blocks.py
@@ -12,6 +12,13 @@
 from e3nn import nn, o3
 from e3nn.util.jit import compile_mode
 
+from mace.modules.wrapper_ops import (
+    CuEquivarianceConfig,
+    FullyConnectedTensorProduct,
+    Linear,
+    SymmetricContractionWrapper,
+    TensorProduct,
+)
 from mace.tools.compile import simplify_if_compile
 from mace.tools.scatter import scatter_sum
 
@@ -29,14 +36,20 @@
     PolynomialCutoff,
     SoftTransform,
 )
-from .symmetric_contraction import SymmetricContraction
 
 
 @compile_mode("script")
 class LinearNodeEmbeddingBlock(torch.nn.Module):
-    def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps):
+    def __init__(
+        self,
+        irreps_in: o3.Irreps,
+        irreps_out: o3.Irreps,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
+    ):
         super().__init__()
-        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)
+        self.linear = Linear(
+            irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config
+        )
 
     def forward(
         self,
@@ -47,9 +60,16 @@ def forward(
 
 @compile_mode("script")
 class LinearReadoutBlock(torch.nn.Module):
-    def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")):
+    def __init__(
+        self,
+        irreps_in: o3.Irreps,
+        irrep_out: o3.Irreps = o3.Irreps("0e"),
+        cueq_config: Optional[CuEquivarianceConfig] = None,
+    ):
         super().__init__()
-        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out)
+        self.linear = Linear(
+            irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config
+        )
 
     def forward(
         self,
@@ -69,13 +89,18 @@ def __init__(
         gate: Optional[Callable],
         irrep_out: o3.Irreps = o3.Irreps("0e"),
         num_heads: int = 1,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
     ):
         super().__init__()
         self.hidden_irreps = MLP_irreps
         self.num_heads = num_heads
-        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps)
+        self.linear_1 = Linear(
+            irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config
+        )
         self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate])
-        self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out)
+        self.linear_2 = Linear(
+            irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config
+        )
 
     def forward(
         self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
@@ -89,13 +114,20 @@ def forward(
 
 @compile_mode("script")
 class LinearDipoleReadoutBlock(torch.nn.Module):
-    def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False):
+    def __init__(
+        self,
+        irreps_in: o3.Irreps,
+        dipole_only: bool = False,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
+    ):
         super().__init__()
         if dipole_only:
             self.irreps_out = o3.Irreps("1x1o")
         else:
             self.irreps_out = o3.Irreps("1x0e + 1x1o")
-        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out)
+        self.linear = Linear(
+            irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
         return self.linear(x)  # [n_nodes, 1]
@@ -109,6 +141,7 @@ def __init__(
         MLP_irreps: o3.Irreps,
         gate: Callable,
         dipole_only: bool = False,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
     ):
         super().__init__()
         self.hidden_irreps = MLP_irreps
@@ -131,9 +164,13 @@ def __init__(
             irreps_gated=irreps_gated,
         )
         self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
-        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin)
-        self.linear_2 = o3.Linear(
-            irreps_in=self.hidden_irreps, irreps_out=self.irreps_out
+        self.linear_1 = Linear(
+            irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config
+        )
+        self.linear_2 = Linear(
+            irreps_in=self.hidden_irreps,
+            irreps_out=self.irreps_out,
+            cueq_config=cueq_config,
        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
@@ -218,22 +255,25 @@ def __init__(
         correlation: int,
         use_sc: bool = True,
         num_elements: Optional[int] = None,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
     ) -> None:
         super().__init__()
         self.use_sc = use_sc
-        self.symmetric_contractions = SymmetricContraction(
+        self.symmetric_contractions = SymmetricContractionWrapper(
             irreps_in=node_feats_irreps,
             irreps_out=target_irreps,
             correlation=correlation,
             num_elements=num_elements,
+            cueq_config=cueq_config,
         )
         # Update linear
-        self.linear = o3.Linear(
+        self.linear = Linear(
             target_irreps,
             target_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=cueq_config,
         )
 
     def forward(
@@ -260,6 +300,7 @@ def __init__(
         hidden_irreps: o3.Irreps,
         avg_num_neighbors: float,
         radial_MLP: Optional[List[int]] = None,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
     ) -> None:
         super().__init__()
         self.node_attrs_irreps = node_attrs_irreps
@@ -272,6 +313,7 @@ def __init__(
         if radial_MLP is None:
             radial_MLP = [64, 64, 64]
         self.radial_MLP = radial_MLP
+        self.cueq_config = cueq_config
 
         self._setup()
 
@@ -325,23 +367,29 @@ def __repr__(self):
 @compile_mode("script")
 class ResidualElementDependentInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
-        self.linear_up = o3.Linear(
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
+
+        # First linear
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
             self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
         self.conv_tp_weights = TensorProductWeightsBlock(
             num_elements=self.node_attrs_irreps.num_irreps,
@@ -353,13 +401,20 @@ def _setup(self) -> None:
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
         self.irreps_out = self.irreps_out.simplify()
-        self.linear = o3.Linear(
-            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        self.linear = Linear(
+            irreps_mid,
+            self.irreps_out,
+            internal_weights=True,
+            shared_weights=True,
+            cueq_config=self.cueq_config,
         )
 
         # Selector TensorProduct
-        self.skip_tp = o3.FullyConnectedTensorProduct(
-            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out
+        self.skip_tp = FullyConnectedTensorProduct(
+            self.node_feats_irreps,
+            self.node_attrs_irreps,
+            self.irreps_out,
+            cueq_config=self.cueq_config,
         )
 
     def forward(
@@ -389,23 +444,27 @@ def forward(
 @compile_mode("script")
 class AgnosticNonlinearInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
-        self.linear_up = o3.Linear(
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
             self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
 
         # Convolution weights
@@ -419,13 +478,20 @@ def _setup(self) -> None:
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
         self.irreps_out = self.irreps_out.simplify()
-        self.linear = o3.Linear(
-            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        self.linear = Linear(
+            irreps_mid,
+            self.irreps_out,
+            internal_weights=True,
+            shared_weights=True,
+            cueq_config=self.cueq_config,
         )
 
         # Selector TensorProduct
-        self.skip_tp = o3.FullyConnectedTensorProduct(
-            self.irreps_out, self.node_attrs_irreps, self.irreps_out
+        self.skip_tp = FullyConnectedTensorProduct(
+            self.irreps_out,
+            self.node_attrs_irreps,
+            self.irreps_out,
+            cueq_config=self.cueq_config,
         )
 
     def forward(
@@ -455,24 +521,28 @@ def forward(
 @compile_mode("script")
 class AgnosticResidualNonlinearInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
         # First linear
-        self.linear_up = o3.Linear(
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
             self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
 
         # Convolution weights
@@ -486,13 +556,20 @@ def _setup(self) -> None:
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
         self.irreps_out = self.irreps_out.simplify()
-        self.linear = o3.Linear(
-            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        self.linear = Linear(
+            irreps_mid,
+            self.irreps_out,
+            internal_weights=True,
+            shared_weights=True,
+            cueq_config=self.cueq_config,
         )
 
         # Selector TensorProduct
-        self.skip_tp = o3.FullyConnectedTensorProduct(
-            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out
+        self.skip_tp = FullyConnectedTensorProduct(
+            self.node_feats_irreps,
+            self.node_attrs_irreps,
+            self.irreps_out,
+            cueq_config=self.cueq_config,
         )
 
     def forward(
@@ -523,12 +600,15 @@ def forward(
 @compile_mode("script")
 class RealAgnosticInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
         # First linear
-        self.linear_up = o3.Linear(
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
@@ -536,13 +616,14 @@ def _setup(self) -> None:
             self.edge_attrs_irreps,
             self.target_irreps,
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
 
         # Convolution weights
@@ -555,15 +636,22 @@ def _setup(self) -> None:
         # Linear
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = self.target_irreps
-        self.linear = o3.Linear(
-            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        self.linear = Linear(
+            irreps_mid,
+            self.irreps_out,
+            internal_weights=True,
+            shared_weights=True,
+            cueq_config=self.cueq_config,
         )
 
         # Selector TensorProduct
-        self.skip_tp = o3.FullyConnectedTensorProduct(
-            self.irreps_out, self.node_attrs_irreps, self.irreps_out
+        self.skip_tp = FullyConnectedTensorProduct(
+            self.irreps_out,
+            self.node_attrs_irreps,
+            self.irreps_out,
+            cueq_config=self.cueq_config,
         )
-        self.reshape = reshape_irreps(self.irreps_out)
+        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
 
     def forward(
         self,
@@ -595,12 +683,15 @@ def forward(
 @compile_mode("script")
 class RealAgnosticResidualInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
         # First linear
-        self.linear_up = o3.Linear(
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
@@ -608,13 +699,14 @@ def _setup(self) -> None:
             self.edge_attrs_irreps,
             self.target_irreps,
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
 
         # Convolution weights
@@ -627,15 +719,22 @@ def _setup(self) -> None:
         # Linear
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = self.target_irreps
-        self.linear = o3.Linear(
-            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        self.linear = Linear(
+            irreps_mid,
+            self.irreps_out,
+            internal_weights=True,
+            shared_weights=True,
+            cueq_config=self.cueq_config,
         )
 
         # Selector TensorProduct
-        self.skip_tp = o3.FullyConnectedTensorProduct(
-            self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
+        self.skip_tp = FullyConnectedTensorProduct(
+            self.node_feats_irreps,
+            self.node_attrs_irreps,
+            self.hidden_irreps,
+            cueq_config=self.cueq_config,
         )
-        self.reshape = reshape_irreps(self.irreps_out)
+        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
 
     def forward(
         self,
@@ -667,12 +766,15 @@ def forward(
 @compile_mode("script")
 class RealAgnosticDensityInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
         # First linear
-        self.linear_up = o3.Linear(
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
@@ -680,13 +782,14 @@ def _setup(self) -> None:
             self.edge_attrs_irreps,
             self.target_irreps,
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
 
         # Convolution weights
@@ -699,15 +802,21 @@ def _setup(self) -> None:
         # Linear
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = self.target_irreps
-        self.linear = o3.Linear(
-            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        self.linear = Linear(
+            irreps_mid,
+            self.irreps_out,
+            internal_weights=True,
+            shared_weights=True,
+            cueq_config=self.cueq_config,
         )
 
         # Selector TensorProduct
-        self.skip_tp = o3.FullyConnectedTensorProduct(
-            self.irreps_out, self.node_attrs_irreps, self.irreps_out
+        self.skip_tp = FullyConnectedTensorProduct(
+            self.irreps_out,
+            self.node_attrs_irreps,
+            self.irreps_out,
+            cueq_config=self.cueq_config,
         )
-        self.reshape = reshape_irreps(self.irreps_out)
 
         # Density normalization
         self.density_fn = nn.FullyConnectedNet(
@@ -718,7 +827,7 @@ def _setup(self) -> None:
             torch.nn.functional.silu,
         )
         # Reshape
-        self.reshape = reshape_irreps(self.irreps_out)
+        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
 
     def forward(
         self,
@@ -754,12 +863,16 @@ def forward(
 @compile_mode("script")
 class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
+
         # First linear
-        self.linear_up = o3.Linear(
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
@@ -767,13 +880,14 @@ def _setup(self) -> None:
             self.edge_attrs_irreps,
             self.target_irreps,
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
 
         # Convolution weights
@@ -786,15 +900,21 @@ def _setup(self) -> None:
         # Linear
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = self.target_irreps
-        self.linear = o3.Linear(
-            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        self.linear = Linear(
+            irreps_mid,
+            self.irreps_out,
+            internal_weights=True,
+            shared_weights=True,
+            cueq_config=self.cueq_config,
         )
 
         # Selector TensorProduct
-        self.skip_tp = o3.FullyConnectedTensorProduct(
-            self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
+        self.skip_tp = FullyConnectedTensorProduct(
+            self.node_feats_irreps,
+            self.node_attrs_irreps,
+            self.hidden_irreps,
+            cueq_config=self.cueq_config,
         )
-        self.reshape = reshape_irreps(self.irreps_out)
 
         # Density normalization
         self.density_fn = nn.FullyConnectedNet(
@@ -806,7 +926,7 @@ def _setup(self) -> None:
         )
 
         # Reshape
-        self.reshape = reshape_irreps(self.irreps_out)
+        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
 
     def forward(
         self,
@@ -842,13 +962,16 @@ def forward(
 @compile_mode("script")
 class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
+        if not hasattr(self, "cueq_config"):
+            self.cueq_config = None
         self.node_feats_down_irreps = o3.Irreps("64x0e")
         # First linear
-        self.linear_up = o3.Linear(
+        self.linear_up = Linear(
             self.node_feats_irreps,
             self.node_feats_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
        )
         # TensorProduct
         irreps_mid, instructions = tp_out_irreps_with_instructions(
@@ -856,21 +979,23 @@ def _setup(self) -> None:
             self.edge_attrs_irreps,
             self.target_irreps,
         )
-        self.conv_tp = o3.TensorProduct(
+        self.conv_tp = TensorProduct(
             self.node_feats_irreps,
             self.edge_attrs_irreps,
             irreps_mid,
             instructions=instructions,
             shared_weights=False,
             internal_weights=False,
+            cueq_config=self.cueq_config,
         )
         # Convolution weights
-        self.linear_down = o3.Linear(
+        self.linear_down = Linear(
             self.node_feats_irreps,
             self.node_feats_down_irreps,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
         input_dim = (
             self.edge_feats_irreps.num_irreps
@@ -884,17 +1009,20 @@ def _setup(self) -> None:
         # Linear
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = self.target_irreps
-        self.linear = o3.Linear(
+        self.linear = Linear(
             irreps_mid,
             self.irreps_out,
             internal_weights=True,
             shared_weights=True,
+            cueq_config=self.cueq_config,
         )
-        self.reshape = reshape_irreps(self.irreps_out)
+        self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config)
         # Skip connection.
-        self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps)
+        self.skip_linear = Linear(
+            self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config
+        )
 
     def forward(
         self,
diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py
index b0960193..2e79c0ab 100644
--- a/mace/modules/irreps_tools.py
+++ b/mace/modules/irreps_tools.py
@@ -4,12 +4,14 @@
 # This program is distributed under the MIT License (see MIT.md)
 ###########################################################################################
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from e3nn import o3
 from e3nn.util.jit import compile_mode
 
+from mace.modules.wrapper_ops import CuEquivarianceConfig
+
 
 # Based on mir-group/nequip
 def tp_out_irreps_with_instructions(
@@ -64,9 +66,12 @@ def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps:
 
 @compile_mode("script")
 class reshape_irreps(torch.nn.Module):
-    def __init__(self, irreps: o3.Irreps) -> None:
+    def __init__(
+        self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None
+    ) -> None:
         super().__init__()
         self.irreps = o3.Irreps(irreps)
+        self.cueq_config = cueq_config
         self.dims = []
         self.muls = []
         for mul, ir in self.irreps:
@@ -81,8 +86,19 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         for mul, d in zip(self.muls, self.dims):
             field = tensor[:, ix : ix + mul * d]  # [batch, sample, mul * repr]
             ix += mul * d
-            field = field.reshape(batch, mul, d)
+            if hasattr(self, "cueq_config") and self.cueq_config is not None:
+                if self.cueq_config.layout_str == "mul_ir":
+                    field = field.reshape(batch, mul, d)
+                else:
+                    field = field.reshape(batch, d, mul)
+            else:
+                field = field.reshape(batch, mul, d)
             out.append(field)
+
+        if hasattr(self, "cueq_config") and self.cueq_config is not None:
+            if self.cueq_config.layout_str == "mul_ir":
+                return torch.cat(out, dim=-1)
+            return torch.cat(out, dim=-2)
         return torch.cat(out, dim=-1)
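A hedged illustration of the default layout handled by reshape_irreps above, on a dummy tensor (the irreps and batch size are arbitrary; mul_ir keeps [batch, mul, 2l+1] blocks, while the ir_mul branch transposes each block to [batch, 2l+1, mul]):

    import torch
    from e3nn import o3
    from mace.modules.irreps_tools import reshape_irreps

    x = torch.randn(10, o3.Irreps("8x0e + 8x1o").dim)  # [batch, 8*1 + 8*3]
    # Without a cueq_config the mul_ir path is taken: each irrep block is
    # reshaped to [batch, mul, 2l+1] and re-concatenated along the last axis.
    out = reshape_irreps(o3.Irreps("8x0e + 8x1o"))(x)  # [10, 8, 4]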
diff --git a/mace/modules/models.py b/mace/modules/models.py
index c0d8ab43..0e03317e 100644
--- a/mace/modules/models.py
+++ b/mace/modules/models.py
@@ -62,6 +62,7 @@ def __init__(
         radial_MLP: Optional[List[int]] = None,
         radial_type: Optional[str] = "bessel",
         heads: Optional[List[str]] = None,
+        cueq_config: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
         self.register_buffer(
@@ -82,7 +83,9 @@ def __init__(
         node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
         node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
         self.node_embedding = LinearNodeEmbeddingBlock(
-            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
+            irreps_in=node_attr_irreps,
+            irreps_out=node_feats_irreps,
+            cueq_config=cueq_config,
         )
         self.radial_embedding = RadialEmbeddingBlock(
             r_max=r_max,
@@ -116,6 +119,7 @@ def __init__(
             hidden_irreps=hidden_irreps,
             avg_num_neighbors=avg_num_neighbors,
             radial_MLP=radial_MLP,
+            cueq_config=cueq_config,
         )
         self.interactions = torch.nn.ModuleList([inter])
 
@@ -131,12 +135,15 @@ def __init__(
             correlation=correlation[0],
             num_elements=num_elements,
             use_sc=use_sc_first,
+            cueq_config=cueq_config,
         )
         self.products = torch.nn.ModuleList([prod])
 
         self.readouts = torch.nn.ModuleList()
         self.readouts.append(
-            LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
+            LinearReadoutBlock(
+                hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
+            )
         )
 
         for i in range(num_interactions - 1):
@@ -155,6 +162,7 @@ def __init__(
                 hidden_irreps=hidden_irreps_out,
                 avg_num_neighbors=avg_num_neighbors,
                 radial_MLP=radial_MLP,
+                cueq_config=cueq_config,
             )
             self.interactions.append(inter)
             prod = EquivariantProductBasisBlock(
@@ -163,6 +171,7 @@ def __init__(
                 correlation=correlation[i + 1],
                 num_elements=num_elements,
                 use_sc=True,
+                cueq_config=cueq_config,
             )
             self.products.append(prod)
             if i == num_interactions - 2:
@@ -173,11 +182,14 @@ def __init__(
                         gate,
                         o3.Irreps(f"{len(heads)}x0e"),
                         len(heads),
+                        cueq_config,
                     )
                 )
             else:
                 self.readouts.append(
-                    LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
+                    LinearReadoutBlock(
+                        hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config
+                    )
                 )
 
     def forward(
@@ -471,6 +483,7 @@ def __init__(
         gate: Optional[Callable],
         avg_num_neighbors: float,
         atomic_numbers: List[int],
+        cueq_config: Optional[Dict[str, Any]] = None,  # pylint: disable=unused-argument
     ):
         super().__init__()
         self.r_max = r_max
@@ -675,6 +688,7 @@ def __init__(
         ],  # Just here to make it compatible with energy models, MUST be None
         radial_type: Optional[str] = "bessel",
         radial_MLP: Optional[List[int]] = None,
+        cueq_config: Optional[Dict[str, Any]] = None,  # pylint: disable=unused-argument
     ):
         super().__init__()
         self.register_buffer(
@@ -876,6 +890,7 @@ def __init__(
         gate: Optional[Callable],
         atomic_energies: Optional[np.ndarray],
         radial_MLP: Optional[List[int]] = None,
+        cueq_config: Optional[Dict[str, Any]] = None,  # pylint: disable=unused-argument
     ):
         super().__init__()
         self.register_buffer(
diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py
new file mode 100644
index 00000000..437d106c
--- /dev/null
+++ b/mace/modules/wrapper_ops.py
@@ -0,0 +1,254 @@
+"""
+Wrapper classes for e3nn operations that optionally use the cuequivariance (cuet) equivalents
+"""
+
+import dataclasses
+import itertools
+import types
+from typing import Iterator, List, Optional
+
+import numpy as np
+import torch
+from e3nn import o3
+
+from mace.modules.symmetric_contraction import SymmetricContraction
+
+try:
+    import cuequivariance as cue
+    import cuequivariance_torch as cuet
+
+    CUET_AVAILABLE = True
+except ImportError:
+    CUET_AVAILABLE = False
+
+if CUET_AVAILABLE:
+
+    class O3_e3nn(cue.O3):
+        def __mul__(  # pylint: disable=no-self-argument
+            rep1: "O3_e3nn", rep2: "O3_e3nn"
+        ) -> Iterator["O3_e3nn"]:
+            return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)]
+
+        @classmethod
+        def clebsch_gordan(
+            cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn"
+        ) -> np.ndarray:
+            rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
+
+            if rep1.p * rep2.p == rep3.p:
+                return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
+                    rep3.dim
+                )
+            return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
+
+        def __lt__(  # pylint: disable=no-self-argument
+            rep1: "O3_e3nn", rep2: "O3_e3nn"
+        ) -> bool:
+            rep2 = rep1._from(rep2)
+            return (rep1.l, rep1.p) < (rep2.l, rep2.p)
+
+        @classmethod
+        def iterator(cls) -> Iterator["O3_e3nn"]:
+            for l in itertools.count(0):
+                yield O3_e3nn(l=l, p=1 * (-1) ** l)
+                yield O3_e3nn(l=l, p=-1 * (-1) ** l)
+
+else:
+    print(
+        "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled."
+    )
+
+
+@dataclasses.dataclass
+class CuEquivarianceConfig:
+    """Configuration for cuequivariance acceleration"""
+
+    enabled: bool = False
+    layout: str = "mul_ir"  # One of: mul_ir, ir_mul
+    layout_str: str = "mul_ir"
+    group: str = "O3"
+    optimize_all: bool = False  # Set to True to enable all optimizations
+    optimize_linear: bool = False
+    optimize_channelwise: bool = False
+    optimize_symmetric: bool = False
+    optimize_fctp: bool = False
+
+    def __post_init__(self):
+        if self.enabled and CUET_AVAILABLE:
+            self.layout_str = self.layout
+            self.layout = getattr(cue, self.layout)
+            self.group = (
+                O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group)
+            )
+
+
+class Linear:
+    """Returns either a cuet.Linear or o3.Linear based on config"""
+
+    def __new__(
+        cls,
+        irreps_in: o3.Irreps,
+        irreps_out: o3.Irreps,
+        shared_weights: bool = True,
+        internal_weights: bool = True,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
+    ):
+        if (
+            CUET_AVAILABLE
+            and cueq_config is not None
+            and cueq_config.enabled
+            and (cueq_config.optimize_all or cueq_config.optimize_linear)
+        ):
+            instance = cuet.Linear(
+                cue.Irreps(cueq_config.group, irreps_in),
+                cue.Irreps(cueq_config.group, irreps_out),
+                layout=cueq_config.layout,
+                shared_weights=shared_weights,
+            )
+            instance._original_forward = instance.forward
+
+            def cuet_forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self._original_forward(x, use_fallback=None)
+
+            instance.forward = types.MethodType(cuet_forward, instance)
+            return instance
+
+        return o3.Linear(
+            irreps_in,
+            irreps_out,
+            shared_weights=shared_weights,
+            internal_weights=internal_weights,
+        )
+
+
+class TensorProduct:
+    """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct"""
+
+    def __new__(
+        cls,
+        irreps_in1: o3.Irreps,
+        irreps_in2: o3.Irreps,
+        irreps_out: o3.Irreps,
+        instructions: Optional[List] = None,
+        shared_weights: bool = False,
+        internal_weights: bool = False,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
+    ):
+        if (
+            CUET_AVAILABLE
+            and cueq_config is not None
+            and cueq_config.enabled
+            and (cueq_config.optimize_all or cueq_config.optimize_channelwise)
+        ):
+            instance = cuet.ChannelWiseTensorProduct(
+                cue.Irreps(cueq_config.group, irreps_in1),
+                cue.Irreps(cueq_config.group, irreps_in2),
+                cue.Irreps(cueq_config.group, irreps_out),
+                layout=cueq_config.layout,
+                shared_weights=shared_weights,
+                internal_weights=internal_weights,
+            )
+            instance._original_forward = instance.forward
+
+            def cuet_forward(
+                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
+            ) -> torch.Tensor:
+                return self._original_forward(x, y, z, use_fallback=None)
+
+            instance.forward = types.MethodType(cuet_forward, instance)
+            return instance
+
+        return o3.TensorProduct(
+            irreps_in1,
+            irreps_in2,
+            irreps_out,
+            instructions=instructions,
+            shared_weights=shared_weights,
+            internal_weights=internal_weights,
+        )
+
+
+class FullyConnectedTensorProduct:
+    """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct"""
+
+    def __new__(
+        cls,
+        irreps_in1: o3.Irreps,
+        irreps_in2: o3.Irreps,
+        irreps_out: o3.Irreps,
+        shared_weights: bool = True,
+        internal_weights: bool = True,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
+    ):
+        if (
+            CUET_AVAILABLE
+            and cueq_config is not None
+            and cueq_config.enabled
+            and (cueq_config.optimize_all or cueq_config.optimize_fctp)
+        ):
+            instance = cuet.FullyConnectedTensorProduct(
+                cue.Irreps(cueq_config.group, irreps_in1),
+                cue.Irreps(cueq_config.group, irreps_in2),
+                cue.Irreps(cueq_config.group, irreps_out),
+                layout=cueq_config.layout,
+                shared_weights=shared_weights,
+                internal_weights=internal_weights,
+            )
+            instance._original_forward = instance.forward
+
+            def cuet_forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor:
+                return self._original_forward(x, attrs, use_fallback=None)
+
+            instance.forward = types.MethodType(cuet_forward, instance)
+            return instance
+
+        return o3.FullyConnectedTensorProduct(
+            irreps_in1,
+            irreps_in2,
+            irreps_out,
+            shared_weights=shared_weights,
+            internal_weights=internal_weights,
+        )
+
+
+class SymmetricContractionWrapper:
+    """Wrapper around SymmetricContraction/cuet.SymmetricContraction"""
+
+    def __new__(
+        cls,
+        irreps_in: o3.Irreps,
+        irreps_out: o3.Irreps,
+        correlation: int,
+        num_elements: Optional[int] = None,
+        cueq_config: Optional[CuEquivarianceConfig] = None,
+    ):
+        if (
+            CUET_AVAILABLE
+            and cueq_config is not None
+            and cueq_config.enabled
+            and (cueq_config.optimize_all or cueq_config.optimize_symmetric)
+        ):
+            instance = cuet.SymmetricContraction(
+                cue.Irreps(cueq_config.group, irreps_in),
+                cue.Irreps(cueq_config.group, irreps_out),
+                layout_in=cue.ir_mul,
+                layout_out=cueq_config.layout,
+                contraction_degree=correlation,
+                num_elements=num_elements,
+                original_mace=True,
+                dtype=torch.get_default_dtype(),
+                math_dtype=torch.get_default_dtype(),
+            )
+            instance._original_forward = instance.forward
+            instance.layout = cueq_config.layout
+
+            def cuet_forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor:
+                if self.layout == cue.mul_ir:
+                    x = torch.transpose(x, 1, 2)
+                index_attrs = torch.nonzero(attrs)[:, 1].int()
+                return self._original_forward(
+                    x.flatten(1),
+                    index_attrs,
+                    use_fallback=None,
+                )
+
+            instance.forward = types.MethodType(cuet_forward, instance)
+            return instance
+
+        return SymmetricContraction(
+            irreps_in=irreps_in,
+            irreps_out=irreps_out,
+            correlation=correlation,
+            num_elements=num_elements,
+        )
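A hedged sketch of how the wrappers above dispatch (the irreps are illustrative; with cuequivariance installed and the config enabled, Linear returns a cuet.Linear, otherwise it silently falls back to o3.Linear):

    from e3nn import o3
    from mace.modules.wrapper_ops import CuEquivarianceConfig, Linear

    cfg = CuEquivarianceConfig(
        enabled=True, layout="mul_ir", group="O3_e3nn", optimize_all=True
    )
    # Dispatches in __new__: cuet.Linear if CUET_AVAILABLE and optimizations
    # are enabled for this op, o3.Linear otherwise.
    lin = Linear(
        o3.Irreps("32x0e + 32x1o"), o3.Irreps("32x0e + 32x1o"), cueq_config=cfg
    )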
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index cb4f8ac5..07e02e49 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -660,6 +660,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         type=check_float_or_none,
         default=10.0,
     )
+    # option for cuequivariance acceleration
+    parser.add_argument(
+        "--enable_cueq",
+        help="Enable cuequivariance acceleration",
+        type=str2bool,
+        default=False,
+    )
     # options for using Weights and Biases for experiment tracking
     # to install see https://wandb.ai
     parser.add_argument(
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index be96558d..1f1be22d 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -175,6 +175,12 @@ def radial_to_transform(radial):
     scale = model.scale_shift.scale
     shift = model.scale_shift.shift
+    try:
+        correlation = (
+            len(model.products[0].symmetric_contractions.contractions[0].weights) + 1
+        )
+    except AttributeError:
+        correlation = model.products[0].symmetric_contractions.contraction_degree
     config = {
         "r_max": model.r_max.item(),
         "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights),
@@ -200,10 +206,7 @@ def radial_to_transform(radial):
         "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
         "avg_num_neighbors": model.interactions[0].avg_num_neighbors,
         "atomic_numbers": model.atomic_numbers,
-        "correlation": len(
-            model.products[0].symmetric_contractions.contractions[0].weights
-        )
-        + 1,
+        "correlation": correlation,
         "radial_type": radial_to_name(
             model.radial_embedding.bessel_fn.__class__.__name__
         ),
diff --git a/pyproject.toml b/pyproject.toml
index 489bc6e5..c7644f78 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,3 +39,6 @@ ignore-paths = [
     "^mace/tools/torch_geometric/.*$",
     "^mace/tools/scatter.py$",
 ]
+
+[tool.pylint.FORMAT]
+max-module-lines = 1500
diff --git a/scripts/run_checks.sh b/scripts/run_checks.sh
old mode 100755
new mode 100644
diff --git a/setup.cfg b/setup.cfg
index 76467fda..ba714d85 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,6 +44,8 @@ console_scripts =
     mace_finetuning = mace.cli.fine_tuning_select:main
     mace_convert_device = mace.cli.convert_device:main
     mace_select_head = mace.cli.select_head:main
+    mace_e3nn_cueq = mace.cli.convert_e3nn_cueq:main
+    mace_cueq_to_e3nn = mace.cli.convert_cueq_e3nn:main
 
 [options.extras_require]
 wandb = wandb
diff --git a/tests/test_cueq.py b/tests/test_cueq.py
new file mode 100644
index 00000000..79bacc6c
--- /dev/null
+++ b/tests/test_cueq.py
@@ -0,0 +1,213 @@
+from typing import Any, Dict
+
+import pytest
+import torch
+import torch.nn.functional as F
+from e3nn import o3
+
+from mace import data, modules, tools
+from e3nn.util import jit
+from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
+from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
+from mace.tools import torch_geometric
+
+try:
+    import cuequivariance as cue  # pylint: disable=unused-import
+
+    CUET_AVAILABLE = True
+except ImportError:
+    CUET_AVAILABLE = False
+
+torch.set_default_dtype(torch.float64)
+
+
+@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
+class TestCueq:
+    @pytest.fixture
+    def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]:
+        table = tools.AtomicNumberTable([6])
+        print("interaction_cls_first", interaction_cls_first)
+        print("hidden_irreps", hidden_irreps)
+        return {
+            "r_max": 5.0,
+            "num_bessel": 8,
+            "num_polynomial_cutoff": 6,
+            "max_ell": 3,
+            "interaction_cls": modules.interaction_classes[
+                "RealAgnosticResidualInteractionBlock"
+            ],
+            "interaction_cls_first": interaction_cls_first,
+            "num_interactions": 2,
+            "num_elements": 1,
+            "hidden_irreps": hidden_irreps,
+            "MLP_irreps": o3.Irreps("16x0e"),
+            "gate": F.silu,
+            "atomic_energies": torch.tensor([1.0]),
+            "avg_num_neighbors": 8,
+            "atomic_numbers": table.zs,
+            "correlation": 3,
+            "radial_type": "bessel",
+            "atomic_inter_scale": 1.0,
+            "atomic_inter_shift": 0.0,
+        }
+
+    @pytest.fixture
+    def batch(self, device: str):
+        from ase import build
+
+        table = tools.AtomicNumberTable([6])
+
+        atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
+        import numpy as np
+
+        displacement = np.random.uniform(-0.1, 0.1, size=atoms.positions.shape)
+        atoms.positions += displacement
+        atoms_list = [atoms.repeat((2, 2, 2))]
+
+        configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
+        data_loader = torch_geometric.dataloader.DataLoader(
+            dataset=[
+                data.AtomicData.from_config(config, z_table=table, cutoff=5.0)
+                for config in configs
+            ],
+            batch_size=1,
+            shuffle=False,
+            drop_last=False,
+        )
+        batch = next(iter(data_loader))
+        return batch.to(device).to_dict()
+
+    @pytest.mark.parametrize("device", ["cuda"])
+    @pytest.mark.parametrize(
+        "interaction_cls_first",
+        [
+            modules.interaction_classes["RealAgnosticResidualInteractionBlock"],
+            modules.interaction_classes["RealAgnosticInteractionBlock"],
+            modules.interaction_classes["RealAgnosticDensityInteractionBlock"],
+        ],
+    )
+    @pytest.mark.parametrize(
+        "hidden_irreps",
+        [
+            # o3.Irreps("32x0e + 32x1o"),
+            # o3.Irreps("32x0e + 32x1o + 32x2e"),
+            o3.Irreps("32x0e"),
+        ],
+    )
+    def test_bidirectional_conversion(
+        self,
+        model_config: Dict[str, Any],
+        batch: Dict[str, torch.Tensor],
+        device: str,
+    ):
+        torch.manual_seed(42)
+
+        # Create original E3nn model
+        model_e3nn = modules.ScaleShiftMACE(**model_config)
+        model_e3nn = model_e3nn.to(device)
+
+        # Convert E3nn to CuEq
+        model_cueq = run_e3nn_to_cueq(model_e3nn)
+        model_cueq = model_cueq.to(device)
+
+        # Convert CuEq back to E3nn
+        model_e3nn_back = run_cueq_to_e3nn(model_cueq)
+        model_e3nn_back = model_e3nn_back.to(device)
+
+        # Test forward pass equivalence
+        out_e3nn = model_e3nn(batch, training=True)
+        out_cueq = model_cueq(batch, training=True)
+        out_e3nn_back = model_e3nn_back(batch, training=True)
+
+        # Check outputs match for both conversions
+        torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
+        torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"])
+        torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"])
+        torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"])
+
+        # Test backward pass equivalence
+        loss_e3nn = out_e3nn["energy"].sum()
+        loss_cueq = out_cueq["energy"].sum()
+        loss_e3nn_back = out_e3nn_back["energy"].sum()
+
+        loss_e3nn.backward()
+        loss_cueq.backward()
+        loss_e3nn_back.backward()
+
+        # Compare gradients for all conversions
+        def print_gradient_diff(name1, p1, name2, p2, conv_type):
+            if p1.grad is not None and p1.grad.shape == p2.grad.shape:
+                if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]:
+                    error = torch.abs(p1.grad - p2.grad)
+                    print(
+                        f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}"
+                    )
+                    torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10)
+
+        # E3nn to CuEq gradients
+        for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip(
+            model_e3nn.named_parameters(), model_cueq.named_parameters()
+        ):
+            print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq")
+
+        # CuEq to E3nn gradients
+        for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip(
+            model_cueq.named_parameters(), model_e3nn_back.named_parameters()
+        ):
+            print_gradient_diff(
+                name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn"
+            )
+
+        # Full circle comparison (E3nn -> E3nn)
+        for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip(
+            model_e3nn.named_parameters(), model_e3nn_back.named_parameters()
+        ):
+            print_gradient_diff(
+                name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle"
+            )
+
+    # def test_jit_compile(
+    #     self,
+    #     model_config: Dict[str, Any],
+    #     batch: Dict[str, torch.Tensor],
+    #     device: str,
+    # ):
+    #     torch.manual_seed(42)
+
+    #     # Create original E3nn model
+    #     model_e3nn = modules.ScaleShiftMACE(**model_config)
+    #     model_e3nn = model_e3nn.to(device)
+
+    #     # Convert E3nn to CuEq
+    #     model_cueq = run_e3nn_to_cueq(model_e3nn)
+    #     model_cueq = model_cueq.to(device)
+
+    #     # Convert CuEq back to E3nn
+    #     model_e3nn_back = run_cueq_to_e3nn(model_cueq)
+    #     model_e3nn_back = model_e3nn_back.to(device)
+
+    #     # # Compile all models
+    #     model_e3nn_compiled = jit.compile(model_e3nn)
+    #     model_cueq_compiled = jit.compile(model_cueq)
+    #     model_e3nn_back_compiled = jit.compile(model_e3nn_back)
+
+    #     # Test forward pass equivalence
+    #     out_e3nn = model_e3nn(batch, training=True)
+    #     out_cueq = model_cueq(batch, training=True)
+    #     out_e3nn_back = model_e3nn_back(batch, training=True)
+
+    #     out_e3nn_compiled = model_e3nn_compiled(batch, training=True)
+    #     out_cueq_compiled = model_cueq_compiled(batch, training=True)
+    #     out_e3nn_back_compiled = model_e3nn_back_compiled(batch, training=True)
+
+    #     # Check outputs match for both conversions
+    #     torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
+    #     torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"])
+    #     torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"])
+    #     torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"])
+
+    #     torch.testing.assert_close(out_e3nn["energy"], out_e3nn_compiled["energy"])
+    #     torch.testing.assert_close(out_cueq["energy"], out_cueq_compiled["energy"])
+    #     torch.testing.assert_close(out_e3nn_back["energy"], out_e3nn_back_compiled["energy"])
+    #     torch.testing.assert_close(out_e3nn["forces"], out_e3nn_compiled["forces"])
+    #     torch.testing.assert_close(out_cueq["forces"], out_cueq_compiled["forces"])
\ No newline at end of file
diff --git a/tests/test_run_train.py b/tests/test_run_train.py
index ca196c47..24d63863 100644
--- a/tests/test_run_train.py
+++ b/tests/test_run_train.py
@@ -847,3 +847,70 @@ def test_run_train_multihead_replay_custum_finetuning(
     assert len(Es) == len(fitting_configs)
     assert all(isinstance(E, float) for E in Es)
     assert len(set(Es)) > 1  # Ens
+
+
+def test_run_train_cueq(tmp_path, fitting_configs):
+    ase.io.write(tmp_path / "fit.xyz", fitting_configs)
+
+    mace_params = _mace_params.copy()
+    mace_params["checkpoints_dir"] = str(tmp_path)
+    mace_params["model_dir"] = str(tmp_path)
+    mace_params["train_file"] = tmp_path / "fit.xyz"
+    mace_params["enable_cueq"] = True
+
+    # make sure run_train.py is using the mace that is currently being tested
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])
+
+    cmd = (
+        sys.executable
+        + " "
+        + str(run_train)
+        + " "
+        + " ".join(
+            [
+                (f"--{k}={v}" if v is not None else f"--{k}")
+                for k, v in mace_params.items()
+            ]
+        )
+    )
+
+    p = subprocess.run(cmd.split(), env=run_env, check=True)
+    assert p.returncode == 0
+
+    calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu")
+
+    Es = []
+    for at in fitting_configs:
+        at.calc = calc
+        Es.append(at.get_potential_energy())
+
+    print("Es", Es)
+    # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
+    ref_Es = [
+        0.0,
+        0.0,
+        -0.039181344585828524,
+        -0.0915223395136733,
+        -0.14953484236456582,
+        -0.06662480820063998,
+        -0.09983737353050133,
+        0.12477442296789745,
+        -0.06486086271762856,
+        -0.1460607988519944,
+        0.12886334908465508,
+        -0.14000990081920373,
+        -0.05319886578958313,
+        0.07780520158391,
+        -0.08895480281886901,
+        -0.15474719614734422,
+        0.007756765146527644,
+        -0.044879267197498685,
+        -0.036065736712447574,
+        -0.24413743841886623,
+        -0.0838104612106429,
+        -0.14751978636626545,
+    ]
+
+    assert np.allclose(Es, ref_Es)
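For reference, the symmetric-contraction repacking performed by the two converters reduces to a concatenate/split round trip along dim=1; a hedged sketch on dummy tensors (the shapes are illustrative, not the real MACE shapes):

    import torch

    w_max = torch.randn(4, 5, 32)  # stands in for ...weights_max
    w_0 = torch.randn(4, 3, 32)    # stands in for ...weights.0
    w_1 = torch.randn(4, 2, 32)    # stands in for ...weights.1

    packed = torch.cat([w_max, w_0, w_1], dim=1)   # e3nn -> CuEq: one fused tensor
    splits = [t.shape[1] for t in (w_max, w_0, w_1)]
    unpacked = torch.split(packed, splits, dim=1)  # CuEq -> e3nn: split by target shapes
    assert all(torch.equal(a, b) for a, b in zip((w_max, w_0, w_1), unpacked))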