diff --git a/README.md b/README.md index 77f8318f..8481760d 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![License](https://img.shields.io/badge/License-MIT%202.0-blue.svg)](https://opensource.org/licenses/mit) [![GitHub issues](https://img.shields.io/github/issues/ACEsuit/mace.svg)](https://GitHub.com/ACEsuit/mace/issues/) [![Documentation Status](https://readthedocs.org/projects/mace/badge/)](https://mace-docs.readthedocs.io/en/latest/) +[![DOI](https://zenodo.org/badge/505964914.svg)](https://doi.org/10.5281/zenodo.14103332) ## Table of contents diff --git a/mace/__version__.py b/mace/__version__.py index 47e8e016..2eb279ae 100644 --- a/mace/__version__.py +++ b/mace/__version__.py @@ -1,3 +1,3 @@ -__version__ = "0.3.7" +__version__ = "0.3.8" __all__ = ["__version__"] diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 5c9a896f..ed814f1a 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -101,8 +101,15 @@ def mace_mp( MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). """ try: - model_path = download_mace_mp_checkpoint(model) - print(f"Using Materials Project MACE for MACECalculator with {model_path}") + if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + model_path = download_mace_mp_checkpoint(model) + print(f"Using Materials Project MACE for MACECalculator with {model_path}") + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + model_path = model except Exception as exc: raise RuntimeError("Model download failed and no local model found") from exc @@ -173,36 +180,42 @@ def mace_off( MACECalculator: trained on the MACE-OFF23 dataset """ try: - urls = dict( - small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", - medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", - large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", - ) - checkpoint_url = ( - urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") - else model - ) - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] - cached_model_path = f"{cache_dir}/{checkpoint_url_name}" - if not os.path.isfile(cached_model_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - print(f"Downloading MACE model from {checkpoint_url!r}") - print( - "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." 
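+        # model may be a named size ("small"/"medium"/"large"), an https URL, or a local checkpoint path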
+ if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + urls = dict( + small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", + medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", + large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", ) - print( - "ASL is based on the Gnu Public License, but does not permit commercial use" + checkpoint_url = ( + urls.get(model, urls["medium"]) + if model in (None, "small", "medium", "large") + else model ) - urllib.request.urlretrieve(checkpoint_url, cached_model_path) - print(f"Cached MACE model to {cached_model_path}") - model = cached_model_path - msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" - print(msg) + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + print(f"Downloading MACE model from {checkpoint_url!r}") + print( + "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." + ) + print( + "ASL is based on the Gnu Public License, but does not permit commercial use" + ) + urllib.request.urlretrieve(checkpoint_url, cached_model_path) + print(f"Cached MACE model to {cached_model_path}") + model = cached_model_path + msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" + print(msg) + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") except Exception as exc: - raise RuntimeError("Model download failed") from exc + raise RuntimeError("Model download failed and no local model found") from exc device = device or ("cuda" if torch.cuda.is_available() else "cpu") diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 9d307eda..dcd2b8e5 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -159,7 +159,7 @@ def __init__( mode=compile_mode, fullgraph=fullgraph, ) - for model in models + for model in self.models ] self.use_compile = True else: diff --git a/mace/cli/active_learning_md.py b/mace/cli/active_learning_md.py index a26be698..9cf4f4a8 100644 --- a/mace/cli/active_learning_md.py +++ b/mace/cli/active_learning_md.py @@ -14,7 +14,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--config", help="path to XYZ configurations", required=True) parser.add_argument( "--config_index", help="index of configuration", type=int, default=-1 diff --git a/mace/cli/convert_device.py b/mace/cli/convert_device.py new file mode 100644 index 00000000..9dd8c61d --- /dev/null +++ b/mace/cli/convert_device.py @@ -0,0 +1,31 @@ +from argparse import ArgumentParser + +import torch + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--target_device", + "-t", + help="device to convert to, usually 'cpu' or 'cuda'", + default="cpu", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + 
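# default output name appends the target device, e.g. "model.pt" -> "model.pt.cpu"
+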
args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file) + model.to(args.target_device) + torch.save(model, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 1917ab8e..507a2cd0 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -7,7 +7,9 @@ def parse_args(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "model_path", type=str, diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index b5700bc4..d00c54c6 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -16,7 +16,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--configs", help="path to XYZ configurations", required=True) parser.add_argument("--model", help="path to model", required=True) parser.add_argument("--output", help="output path", required=True) @@ -53,6 +55,13 @@ def parse_args() -> argparse.Namespace: type=str, default="MACE_", ) + parser.add_argument( + "--head", + help="Model head used for evaluation", + type=str, + required=False, + default=None, + ) return parser.parse_args() @@ -76,14 +85,22 @@ def run(args: argparse.Namespace) -> None: # Load data and prepare input atoms_list = ase.io.read(args.configs, index=":") + if args.head is not None: + for atoms in atoms_list: + atoms.info["head"] = args.head configs = [data.config_from_atoms(atoms) for atoms in atoms_list] z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + try: + heads = model.heads + except AttributeError: + heads = None + data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=z_table, cutoff=float(model.r_max) + config, z_table=z_table, cutoff=float(model.r_max), heads=heads ) for config in configs ], diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 2fa5f644..94baf0dd 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -20,7 +20,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--configs_pt", help="path to XYZ configurations for the pretraining", diff --git a/mace/cli/plot_train.py b/mace/cli/plot_train.py index c249d76a..a1c424df 100644 --- a/mace/cli/plot_train.py +++ b/mace/cli/plot_train.py @@ -60,7 +60,10 @@ def parse_training_results(path: str) -> List[dict]: def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Plot mace training statistics") + parser = argparse.ArgumentParser( + description="Plot mace training statistics", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--path", help="path to results file or directory", required=True ) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 560f91ee..00302cf1 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -36,7 +36,6 @@ LRScheduler, check_path_ase_read, convert_to_json_format, - create_error_table, dict_to_array, extract_config_mace_model, get_atomic_energies, @@ -49,9 +48,11 @@ get_params_options, get_swa, print_git_commit, + remove_pt_head, setup_wandb, 
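+    # create_error_table now lives in mace.tools.tables_utils (imported below)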
) from mace.tools.slurm_distributed import DistributedEnvironment +from mace.tools.tables_utils import create_error_table from mace.tools.utils import AtomicNumberTable @@ -115,10 +116,6 @@ def run(args: argparse.Namespace) -> None: commit = print_git_commit() model_foundation: Optional[torch.nn.Module] = None if args.foundation_model is not None: - if args.multiheads_finetuning: - assert ( - args.E0s != "average" - ), "average atomic energies cannot be used for multiheads finetuning" if args.foundation_model in ["small", "medium", "large"]: logging.info( f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." ) @@ -148,6 +145,27 @@ def run(args: argparse.Namespace) -> None: f"Using foundation model {args.foundation_model} as initial checkpoint." ) args.r_max = model_foundation.r_max.item() + if ( + args.foundation_model not in ["small", "medium", "large"] + and args.pt_train_file is None + ): + logging.warning( + "Using multiheads finetuning with a foundation model that is not a Materials Project model; you need to provide a path to a pretraining file with --pt_train_file." + ) + args.multiheads_finetuning = False + if args.multiheads_finetuning: + assert ( + args.E0s != "average" + ), "average atomic energies cannot be used for multiheads finetuning" + # check that the foundation model has a single head, if not, use the first head + if hasattr(model_foundation, "heads"): + if len(model_foundation.heads) > 1: + logging.warning( + "Multihead finetuning with models with more than one head is not supported; using the first head as the foundation head." + ) + model_foundation = remove_pt_head( + model_foundation, args.foundation_head + ) else: args.multiheads_finetuning = False @@ -353,8 +371,14 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict[head_config.head_name] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs @@ -372,8 +396,14 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict["pt_head"] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs @@ -575,7 +605,6 @@ def run(args: argparse.Namespace) -> None: distributed_model = DDP(model, device_ids=[local_rank]) else: distributed_model = None - tools.train( model=model, loss_fn=loss_fn, @@ -654,7 +683,6 @@ def run(args: argparse.Namespace) -> None: folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) for
test_name, test_set in test_sets.items(): - print(test_name) test_sampler = None if args.distributed: test_sampler = torch.utils.data.distributed.DistributedSampler( diff --git a/mace/cli/select_head.py b/mace/cli/select_head.py new file mode 100644 index 00000000..a1e27229 --- /dev/null +++ b/mace/cli/select_head.py @@ -0,0 +1,33 @@ +from argparse import ArgumentParser + +import torch + +from mace.tools.scripts_utils import remove_pt_head + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--head_name", + "-n", + help="name of the head to extract", + default=None, + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.head_name", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + str(args.head_name) + + model = torch.load(args.model_file) + model_single = remove_pt_head(model, args.head_name) + torch.save(model_single, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 9278130f..e48e0b23 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -15,6 +15,8 @@ NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, @@ -56,6 +58,8 @@ "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, } scaling_classes: Dict[str, Callable] = { diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 34539b0b..0db3b02e 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -664,6 +664,181 @@ def forward( ) # [n_nodes, channels, (lmax + 1)**2] +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], +
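# maps radial edge features to one scalar per edge; forward turns this into an edge density via tanh(density_fn(edge_feats) ** 2)
+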
torch.nn.functional.silu, + ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 
eecb0feb..1cb28fc3 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -15,6 +15,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser = configargparse.ArgumentParser( config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add( "--config", @@ -23,7 +24,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: help="config file to agregate options", ) except ImportError: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) # Name and seed parser.add_argument("--name", help="experiment name", required=True) @@ -153,6 +156,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "RealAgnosticResidualInteractionBlock", "RealAgnosticAttResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( @@ -163,6 +168,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: choices=[ "RealAgnosticResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( @@ -353,6 +360,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str2bool, default=True, ) + parser.add_argument( + "--foundation_head", + help="Name of the head to use for fine-tuning", + type=str, + default=None, + required=False, + ) parser.add_argument( "--weight_pt_head", help="Weight of the pretrained head in the loss function", @@ -706,7 +720,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: def build_preprocess_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--train_file", help="Training set h5 file", diff --git a/mace/tools/checkpoint.py b/mace/tools/checkpoint.py index 8a62f1f2..81161ccc 100644 --- a/mace/tools/checkpoint.py +++ b/mace/tools/checkpoint.py @@ -64,7 +64,7 @@ def __init__( self._filename_extension = "pt" def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: - if swa_start is not None and epochs > swa_start: + if swa_start is not None and epochs >= swa_start: return ( self.tag + self._epochs_string diff --git a/mace/tools/compile.py b/mace/tools/compile.py index 425e4c02..03282067 100644 --- a/mace/tools/compile.py +++ b/mace/tools/compile.py @@ -36,7 +36,7 @@ def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: """ if allow_autograd: dynamo.allow_in_graph(autograd.grad) - elif dynamo.allowed_functions.is_allowed(autograd.grad): + else: dynamo.disallow_in_graph(autograd.grad) @wraps(func) diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 0d4e2f52..8df0b0d1 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -73,10 +73,10 @@ def load_foundations_elements( model.interactions[i].linear.weight = torch.nn.Parameter( model_foundations.interactions[i].linear.weight.clone() ) - if ( - model.interactions[i].__class__.__name__ - == "RealAgnosticResidualInteractionBlock" - ): + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: model.interactions[i].skip_tp.weight = torch.nn.Parameter( model_foundations.interactions[i] 
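+                # copy the foundation model's skip_tp weights (reshaped and rescaled below for the retained elements)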
.skip_tp.weight.reshape( @@ -101,7 +101,19 @@ def load_foundations_elements( .clone() / (num_species_foundations / num_species) ** 0.5 ) - + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: + # Assuming only 1 layer in density_fn + getattr(model.interactions[i].density_fn, "layer0").weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].density_fn, + "layer0", + ).weight.clone() + ) + ) # Transferring products for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py index 8e8c2877..3f49eb41 100644 --- a/mace/tools/model_script_utils.py +++ b/mace/tools/model_script_utils.py @@ -53,7 +53,6 @@ def configure_model( model_config_foundation["atomic_inter_shift"] = ( _determine_atomic_inter_shift(args.mean, heads) ) - model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] args.model = "FoundationMACE" diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index ac9d09fb..be96558d 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -17,11 +17,9 @@ import torch import torch.distributed from e3nn import o3 -from prettytable import PrettyTable from torch.optim.swa_utils import SWALR, AveragedModel from mace import data, modules, tools -from mace.tools import evaluate from mace.tools.train import SWAContainer @@ -224,6 +222,98 @@ def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: ) +def remove_pt_head( + model: torch.nn.Module, head_to_keep: Optional[str] = None +) -> torch.nn.Module: + """Converts a multihead MACE model to a single head model by removing the pretraining head. + + Args: + model (ScaleShiftMACE): The multihead MACE model to convert + head_to_keep (Optional[str]): The name of the head to keep. If None, keeps the first non-PT head. 
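+            For example "DFT"; the pretraining head itself is named "pt_head".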
+ + Returns: + ScaleShiftMACE: A new MACE model with only the specified head + + Raises: + ValueError: If the model is not a multihead model or if the specified head is not found + """ + if not hasattr(model, "heads") or len(model.heads) <= 1: + raise ValueError("Model must be a multihead model with more than one head") + + # Get index of head to keep + if head_to_keep is None: + # Find first non-PT head + try: + head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head") + except StopIteration as e: + raise ValueError("No non-PT head found in model") from e + else: + try: + head_idx = model.heads.index(head_to_keep) + except ValueError as e: + raise ValueError(f"Head {head_to_keep} not found in model") from e + + # Extract config and modify for single head + model_config = extract_config_mace_model(model) + model_config["heads"] = [model.heads[head_idx]] + model_config["atomic_energies"] = ( + model.atomic_energies_fn.atomic_energies[head_idx] + .unsqueeze(0) + .detach() + .cpu() + .numpy() + ) + model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item() + model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item() + mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) // len(model.heads) + model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e") + + new_model = model.__class__(**model_config) + state_dict = model.state_dict() + new_state_dict = {} + + for name, param in state_dict.items(): + if "atomic_energies" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "scale" in name or "shift" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "readouts" in name: + channels_per_head = param.shape[0] // len(model.heads) + start_idx = head_idx * channels_per_head + end_idx = start_idx + channels_per_head + if "linear_2.weight" in name: + end_idx = start_idx + channels_per_head // 2 + # if ( + # "readouts.0.linear.weight" in name + # or "readouts.1.linear_2.weight" in name + # ): + # new_state_dict[name] = param[start_idx:end_idx] / ( + # len(model.heads) ** 0.5 + # ) + if "readouts.0.linear.weight" in name: + new_state_dict[name] = param.reshape(-1, len(model.heads))[ + :, head_idx + ].flatten() + elif "readouts.1.linear_1.weight" in name: + new_state_dict[name] = param.reshape( + -1, len(model.heads), mlp_count_irreps + )[:, head_idx, :].flatten() + elif "readouts.1.linear_2.weight" in name: + new_state_dict[name] = param.reshape( + len(model.heads), -1, len(model.heads) + )[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5) + else: + new_state_dict[name] = param[start_idx:end_idx] + + else: + new_state_dict[name] = param + + # Load state dict into new model + new_model.load_state_dict(new_state_dict) + + return new_model + + def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: model_copy = model.__class__(**extract_config_mace_model(model)) model_copy.load_state_dict(model.state_dict()) @@ -327,6 +417,9 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: logging.info(f"Loading atomic energies from {E0s}") with open(E0s, "r", encoding="utf-8") as f: atomic_energies_dict = json.load(f) + atomic_energies_dict = { + int(key): value for key, value in atomic_energies_dict.items() + } else: try: atomic_energies_eval = ast.literal_eval(E0s) @@ -610,22 +703,11 @@ def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: ] -def custom_key(key): - """ - Helper function to sort the keys of the data loader dictionary - to ensure 
that the training set, and validation set - are evaluated first - """ - if key == "train": - return (0, key) - if key == "valid": - return (1, key) - return (2, key) - - def dict_to_array(input_data, heads): + if all(isinstance(value, np.ndarray) for value in input_data.values()): + return np.array([input_data[head] for head in heads]) if not all(isinstance(value, dict) for value in input_data.values()): - return np.array(list(input_data.values())) + return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): unique_keys.update(inner_dict.keys()) @@ -637,7 +719,7 @@ def dict_to_array(input_data, heads): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) result_array[head_index][key_index] = value - return np.squeeze(result_array) + return result_array class LRScheduler: @@ -675,227 +757,6 @@ def __getattr__(self, name): return getattr(self.lr_scheduler, name) -def create_error_table( - table_type: str, - all_data_loaders: dict, - model: torch.nn.Module, - loss_fn: torch.nn.Module, - output_args: Dict[str, bool], - log_wandb: bool, - device: str, - distributed: bool = False, -) -> PrettyTable: - if log_wandb: - import wandb - table = PrettyTable() - if table_type == "TotalRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSEstressvirials": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - "RMSE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "PerAtomMAEstressvirials": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - "MAE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "TotalMAE": - table.field_names = [ - "config_type", - "MAE E / meV", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "PerAtomMAE": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "DipoleRMSE": - table.field_names = [ - "config_type", - "RMSE MU / mDebye / atom", - "relative MU RMSE %", - ] - elif table_type == "DipoleMAE": - table.field_names = [ - "config_type", - "MAE MU / mDebye / atom", - "relative MU MAE %", - ] - elif table_type == "EnergyDipoleRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "rel F RMSE %", - "RMSE MU / mDebye / atom", - "rel MU RMSE %", - ] - - for name in sorted(all_data_loaders, key=custom_key): - data_loader = all_data_loaders[name] - logging.info(f"Evaluating {name} ...") - _, metrics = evaluate( - model, - loss_fn=loss_fn, - data_loader=data_loader, - output_args=output_args, - device=device, - ) - if distributed: - torch.distributed.barrier() - - del data_loader - torch.cuda.empty_cache() - if log_wandb: - wandb_log_dict = { - name - + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] - * 1e3, # meV / atom - name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A - name + "_final_rel_rmse_f": metrics["rel_rmse_f"], - } - wandb.log(wandb_log_dict) - if table_type == "TotalRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif table_type == "PerAtomRMSE": - table.add_row( 
- [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_virials'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_virials'] * 1000:8.1f}", - ] - ) - elif table_type == "TotalMAE": - table.add_row( - [ - name, - f"{metrics['mae_e'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "PerAtomMAE": - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "DipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - elif table_type == "DipoleMAE": - table.add_row( - [ - name, - f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_mae_mu']:8.1f}", - ] - ) - elif table_type == "EnergyDipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.1f}", - f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - return table - - def check_folder_subfolder(folder_path): entries = os.listdir(folder_path) for entry in entries: diff --git a/mace/tools/tables_utils.py b/mace/tools/tables_utils.py new file mode 100644 index 00000000..07f41401 --- /dev/null +++ b/mace/tools/tables_utils.py @@ -0,0 +1,241 @@ +import logging +from typing import Dict + +import torch +from prettytable import PrettyTable + +from mace.tools import evaluate + + +def custom_key(key): + """ + Helper function to sort the keys of the data loader dictionary + to ensure that the training set, and validation set + are evaluated first + """ + if key == "train": + return (0, key) + if key == "valid": + return (1, key) + return (2, key) + + +def create_error_table( + table_type: str, + all_data_loaders: dict, + model: torch.nn.Module, + loss_fn: torch.nn.Module, + output_args: Dict[str, bool], + log_wandb: bool, + device: str, + distributed: bool = False, +) -> PrettyTable: + if log_wandb: + import wandb + table = PrettyTable() + if table_type == "TotalRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative 
F RMSE %", + ] + elif table_type == "PerAtomRMSEstressvirials": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + "RMSE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "TotalMAE": + table.field_names = [ + "config_type", + "MAE E / meV", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "PerAtomMAE": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "DipoleRMSE": + table.field_names = [ + "config_type", + "RMSE MU / mDebye / atom", + "relative MU RMSE %", + ] + elif table_type == "DipoleMAE": + table.field_names = [ + "config_type", + "MAE MU / mDebye / atom", + "relative MU MAE %", + ] + elif table_type == "EnergyDipoleRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "rel F RMSE %", + "RMSE MU / mDebye / atom", + "rel MU RMSE %", + ] + + for name in sorted(all_data_loaders, key=custom_key): + data_loader = all_data_loaders[name] + logging.info(f"Evaluating {name} ...") + _, metrics = evaluate( + model, + loss_fn=loss_fn, + data_loader=data_loader, + output_args=output_args, + device=device, + ) + if distributed: + torch.distributed.barrier() + + del data_loader + torch.cuda.empty_cache() + if log_wandb: + wandb_log_dict = { + name + + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] + * 1e3, # meV / atom + name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A + name + "_final_rel_rmse_f": metrics["rel_rmse_f"], + } + wandb.log(wandb_log_dict) + if table_type == "TotalRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif table_type == "PerAtomRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", + ] + ) + elif table_type == "TotalMAE": + table.add_row( + [ + name, + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type 
== "PerAtomMAE": + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "DipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + elif table_type == "DipoleMAE": + table.add_row( + [ + name, + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", + ] + ) + elif table_type == "EnergyDipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + return table diff --git a/mace/tools/train.py b/mace/tools/train.py index 3c39415c..5e034401 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -61,7 +61,7 @@ def valid_err_log( error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -71,7 +71,7 @@ def valid_err_log( error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress={error_stress:8.1f} meV / A^3", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -81,7 +81,7 @@ def valid_err_log( error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV", ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -91,7 +91,7 @@ def valid_err_log( error_f = eval_metrics["mae_f"] * 1e3 error_stress = eval_metrics["mae_stress"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3" + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3" ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -101,37 +101,37 @@ def valid_err_log( error_f = eval_metrics["mae_f"] * 1e3 error_virials = eval_metrics["mae_virials"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV" + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = 
eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", ) diff --git a/setup.cfg b/setup.cfg index c548140f..83401d52 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,8 @@ console_scripts = mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main mace_finetuning = mace.cli.fine_tuning_select:main + mace_convert_device = mace.cli.convert_device:main + mace_select_head = mace.cli.select_head:main [options.extras_require] wandb = wandb @@ -53,5 +55,6 @@ dev = mypy pre-commit pytest + pytest-benchmark pylint -schedulefree = schedulefree \ No newline at end of file +schedulefree = schedulefree diff --git a/tests/test_compile.py b/tests/test_compile.py index 01106bef..d7d585e8 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -42,8 +42,10 @@ def create_mace(device: str, seed: int = 1702): "atomic_numbers": table.zs, "correlation": 3, "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, } - model = modules.MACE(**model_config) + model = modules.ScaleShiftMACE(**model_config) return model.to(device) @@ -122,11 +124,14 @@ def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621 @pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"]) @pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"]) def test_compile_benchmark(benchmark, compile_mode, enable_amp): + if enable_amp: + pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default") + with tools.torch_tools.default_dtype(torch.float32): batch = create_batch("cuda") 
torch.compiler.reset() model = mace_compile.prepare(create_mace)("cuda") - model = torch.compile(model, mode=compile_mode, fullgraph=True) + model = torch.compile(model, mode=compile_mode) model = time_func(model) with torch.autocast("cuda", enabled=enable_amp): diff --git a/tests/test_foundations.py b/tests/test_foundations.py index fa35f8b9..44879395 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytest import torch @@ -10,9 +12,17 @@ from mace.calculators import mace_mp, mace_off from mace.tools import torch_geometric from mace.tools.finetuning_utils import load_foundations_elements -from mace.tools.scripts_utils import extract_config_mace_model +from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head from mace.tools.utils import AtomicNumberTable +MODEL_PATH = ( + Path(__file__).parent.parent + / "mace" + / "calculators" + / "foundations_models" + / "2023-12-03-mace-mp.model" +) + torch.set_default_dtype(torch.float64) config = data.Configuration( atomic_numbers=molecule("H2COH").numbers, @@ -172,9 +182,11 @@ def test_multi_reference(): mace_mp(model="small", device="cpu", default_dtype="float64").models[0], mace_mp(model="medium", device="cpu", default_dtype="float64").models[0], mace_mp(model="large", device="cpu", default_dtype="float64").models[0], + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], mace_off(model="small", device="cpu", default_dtype="float64").models[0], mace_off(model="medium", device="cpu", default_dtype="float64").models[0], mace_off(model="large", device="cpu", default_dtype="float64").models[0], + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], ], ) def test_extract_config(model): @@ -196,3 +208,240 @@ def test_extract_config(model): for key in output.keys(): if isinstance(output[key], torch.Tensor): assert torch.allclose(output[key], output_copy[key], atol=1e-5) + + +def test_remove_pt_head(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT"], + "atomic_inter_scale": [1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test molecule + mol = molecule("H2O") + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + energy=1.0, + forces=np.random.randn(len(mol), 3), + head="DFT", + ) + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["pt_head", "DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + # Test original mode + output_orig = model(batch) + + # Convert to single head model + 
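+    # remove_pt_head rebuilds the model with only the requested head, slicing the matching E0s, scale/shift and readout weights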
new_model = remove_pt_head(model, head_to_keep="DFT") + + # Basic structure tests + assert len(new_model.heads) == 1 + assert new_model.heads[0] == "DFT" + assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + + # Test output consistency + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + output_new = new_model(batch) + torch.testing.assert_close( + output_orig["energy"], output_new["energy"], rtol=1e-5, atol=1e-5 + ) + torch.testing.assert_close( + output_orig["forces"], output_new["forces"], rtol=1e-5, atol=1e-5 + ) + + +def test_remove_pt_head_multihead(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array( + [ + [1.0, 2.0], # H energies for each head + [3.0, 4.0], # O energies for each head + ] + * 2 + ) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT", "MP2", "CCSD"], + "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test configurations for each head + mol = molecule("H2O") + configs = {} + atomic_datas = {} + dataloaders = {} + original_outputs = {} + + # First get outputs from original model for each head + for head in model.heads: + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + energy=1.0, + forces=np.random.randn(len(mol), 3), + head=head, + ) + configs[head] = config_pt_head + + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads + ) + atomic_datas[head] = atomic_data + + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + dataloaders[head] = dataloader + + batch = next(iter(dataloader)) + output = model(batch) + original_outputs[head] = output + + # Now test each head separately + for i, head in enumerate(model.heads): + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep=head) + + # Basic structure tests + assert len(new_model.heads) == 1, f"Failed for head {head}" + assert new_model.heads[0] == head, f"Failed for head {head}" + assert ( + new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + ), f"Failed for head {head}" + + # Verify scale and shift values + assert torch.allclose( + new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1] + ), f"Failed for 
head {head}" + assert torch.allclose( + new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1] + ), f"Failed for head {head}" + + # Test output consistency + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + new_output = new_model(batch) + + # Compare outputs + print( + original_outputs[head]["energy"], + new_output["energy"], + ) + torch.testing.assert_close( + original_outputs[head]["energy"], + new_output["energy"], + rtol=1e-5, + atol=1e-5, + msg=f"Energy mismatch for head {head}", + ) + torch.testing.assert_close( + original_outputs[head]["forces"], + new_output["forces"], + rtol=1e-5, + atol=1e-5, + msg=f"Forces mismatch for head {head}", + ) + + # Test error cases + with pytest.raises(ValueError, match="Head non_existent not found in model"): + remove_pt_head(model, head_to_keep="non_existent") + + # Test default behavior (first non-PT head) + default_model = remove_pt_head(model) + assert default_model.heads[0] == "DFT" + + # Additional test: check if each model's computation graph is independent + models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads} + results = {} + + for head, head_model in models.items(): + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + results[head] = head_model(batch) + + # Verify each model produces different outputs + energies = torch.stack([results[head]["energy"] for head in model.heads]) + assert not torch.allclose( + energies[0], energies[1], rtol=1e-3 + ), "Different heads should produce different outputs" diff --git a/tests/test_run_train.py b/tests/test_run_train.py index ca003317..ca196c47 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -1,3 +1,4 @@ +import json import os import subprocess import sys @@ -600,6 +601,123 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): assert np.allclose(Es, ref_Es, atol=1e-1) +def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + if i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + + heads = { + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: 
{sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + def test_run_train_multihead_replay_custum_finetuning( tmp_path, fitting_configs, pretraining_configs ):