diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py
index 5c9a896f..ed814f1a 100644
--- a/mace/calculators/foundations_models.py
+++ b/mace/calculators/foundations_models.py
@@ -101,8 +101,15 @@ def mace_mp(
         MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
     """
     try:
-        model_path = download_mace_mp_checkpoint(model)
-        print(f"Using Materials Project MACE for MACECalculator with {model_path}")
+        if model in (None, "small", "medium", "large") or str(model).startswith(
+            "https:"
+        ):
+            model_path = download_mace_mp_checkpoint(model)
+            print(f"Using Materials Project MACE for MACECalculator with {model_path}")
+        else:
+            if not Path(model).exists():
+                raise FileNotFoundError(f"{model} not found locally")
+            model_path = model
     except Exception as exc:
         raise RuntimeError("Model download failed and no local model found") from exc
 
@@ -173,36 +180,42 @@ def mace_off(
         MACECalculator: trained on the MACE-OFF23 dataset
     """
     try:
-        urls = dict(
-            small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
-            medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
-            large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
-        )
-        checkpoint_url = (
-            urls.get(model, urls["medium"])
-            if model in (None, "small", "medium", "large")
-            else model
-        )
-        cache_dir = os.path.expanduser("~/.cache/mace")
-        checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
-        cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
-        if not os.path.isfile(cached_model_path):
-            os.makedirs(cache_dir, exist_ok=True)
-            # download and save to disk
-            print(f"Downloading MACE model from {checkpoint_url!r}")
-            print(
-                "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
-            )
-            print(
-                "ASL is based on the Gnu Public License, but does not permit commercial use"
-            )
-            urllib.request.urlretrieve(checkpoint_url, cached_model_path)
-            print(f"Cached MACE model to {cached_model_path}")
-        model = cached_model_path
-        msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
-        print(msg)
+        if model in (None, "small", "medium", "large") or str(model).startswith(
+            "https:"
+        ):
+            urls = dict(
+                small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
+                medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
+                large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
+            )
+            checkpoint_url = (
+                urls.get(model, urls["medium"])
+                if model in (None, "small", "medium", "large")
+                else model
+            )
+            cache_dir = os.path.expanduser("~/.cache/mace")
+            checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
+            cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
+            if not os.path.isfile(cached_model_path):
+                os.makedirs(cache_dir, exist_ok=True)
+                # download and save to disk
+                print(f"Downloading MACE model from {checkpoint_url!r}")
+                print(
+                    "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
+                )
+                print(
+                    "ASL is based on the Gnu Public License, but does not permit commercial use"
+                )
+                urllib.request.urlretrieve(checkpoint_url, cached_model_path)
+                print(f"Cached MACE model to {cached_model_path}")
+            model = cached_model_path
+            msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
+            print(msg)
+        else:
+            if not Path(model).exists():
+                raise FileNotFoundError(f"{model} not found locally")
     except Exception as exc:
-        raise RuntimeError("Model download failed") from exc
+        raise RuntimeError("Model download failed and no local model found") from exc
 
     device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py
index 9d307eda..dcd2b8e5 100644
--- a/mace/calculators/mace.py
+++ b/mace/calculators/mace.py
@@ -159,7 +159,7 @@ def __init__(
                     mode=compile_mode,
                     fullgraph=fullgraph,
                 )
-                for model in models
+                for model in self.models
             ]
             self.use_compile = True
         else:
diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_dev.py
new file mode 100644
index 00000000..9dd8c61d
--- /dev/null
+++ b/mace/cli/convert_dev.py
@@ -0,0 +1,31 @@
+from argparse import ArgumentParser
+
+import torch
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--target_device",
+        "-t",
+        help="device to convert to, usually 'cpu' or 'cuda'",
+        default="cpu",
+    )
+    parser.add_argument(
+        "--output_file",
+        "-o",
+        help="name for output model, defaults to model_file.target_device",
+    )
+    parser.add_argument("model_file", help="input model file path")
+    args = parser.parse_args()
+
+    if args.output_file is None:
+        args.output_file = args.model_file + "." + args.target_device
+
+    model = torch.load(args.model_file)
+    model.to(args.target_device)
+    torch.save(model, args.output_file)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py
index b5700bc4..86470f8d 100644
--- a/mace/cli/eval_configs.py
+++ b/mace/cli/eval_configs.py
@@ -53,6 +53,13 @@ def parse_args() -> argparse.Namespace:
         type=str,
         default="MACE_",
     )
+    parser.add_argument(
+        "--head",
+        help="Model head used for evaluation",
+        type=str,
+        required=False,
+        default=None,
+    )
     return parser.parse_args()
 
 
@@ -76,14 +83,22 @@ def run(args: argparse.Namespace) -> None:
 
     # Load data and prepare input
     atoms_list = ase.io.read(args.configs, index=":")
+    if args.head is not None:
+        for atoms in atoms_list:
+            atoms.info["head"] = args.head
     configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
 
     z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])
 
+    try:
+        heads = model.heads
+    except AttributeError:
+        heads = None
+
     data_loader = torch_geometric.dataloader.DataLoader(
         dataset=[
             data.AtomicData.from_config(
-                config, z_table=z_table, cutoff=float(model.r_max)
+                config, z_table=z_table, cutoff=float(model.r_max), heads=heads
             )
             for config in configs
         ],
diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 8cab392e..7f1a5e74 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -353,8 +353,14 @@ def run(args: argparse.Namespace) -> None:
                 z_table_foundation = AtomicNumberTable(
                     [int(z) for z in model_foundation.atomic_numbers]
                 )
+                foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
+                if foundation_atomic_energies.ndim > 1:
+                    foundation_atomic_energies = foundation_atomic_energies.squeeze()
+                    if foundation_atomic_energies.ndim == 2:
+                        foundation_atomic_energies = foundation_atomic_energies[0]
+                        logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
                 atomic_energies_dict[head_config.head_name] = {
-                    z: model_foundation.atomic_energies_fn.atomic_energies[
+                    z: foundation_atomic_energies[
                         z_table_foundation.z_to_index(z)
                     ].item()
                     for z in z_table.zs
@@ -372,8 +378,14 @@ def run(args: argparse.Namespace) -> None:
             z_table_foundation = AtomicNumberTable(
                 [int(z) for z in model_foundation.atomic_numbers]
             )
+            foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
+            if foundation_atomic_energies.ndim > 1:
+                foundation_atomic_energies = foundation_atomic_energies.squeeze()
+                if foundation_atomic_energies.ndim == 2:
+                    foundation_atomic_energies = foundation_atomic_energies[0]
+                    logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
             atomic_energies_dict["pt_head"] = {
-                z: model_foundation.atomic_energies_fn.atomic_energies[
+                z: foundation_atomic_energies[
                     z_table_foundation.z_to_index(z)
                 ].item()
                 for z in z_table.zs
@@ -653,7 +665,6 @@ def run(args: argparse.Namespace) -> None:
         folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
     )
     for test_name, test_set in test_sets.items():
-        print(test_name)
         test_sampler = None
         if args.distributed:
             test_sampler = torch.utils.data.distributed.DistributedSampler(
diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py
index 9278130f..e48e0b23 100644
--- a/mace/modules/__init__.py
+++ b/mace/modules/__init__.py
@@ -15,6 +15,8 @@
     NonLinearReadoutBlock,
     RadialEmbeddingBlock,
     RealAgnosticAttResidualInteractionBlock,
+    RealAgnosticDensityInteractionBlock,
+    RealAgnosticDensityResidualInteractionBlock,
     RealAgnosticInteractionBlock,
     RealAgnosticResidualInteractionBlock,
     ResidualElementDependentInteractionBlock,
@@ -56,6 +58,8 @@
     "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
     "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
     "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
+    "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
+    "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
 }
 
 scaling_classes: Dict[str, Callable] = {
diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py
index 34539b0b..0db3b02e 100644
--- a/mace/modules/blocks.py
+++ b/mace/modules/blocks.py
@@ -664,6 +664,181 @@ def forward(
         )  # [n_nodes, channels, (lmax + 1)**2]
 
 
+@compile_mode("script")
+class RealAgnosticDensityInteractionBlock(InteractionBlock):
+    def _setup(self) -> None:
+        # First linear
+        self.linear_up = o3.Linear(
+            self.node_feats_irreps,
+            self.node_feats_irreps,
+            internal_weights=True,
+            shared_weights=True,
+        )
+        # TensorProduct
+        irreps_mid, instructions = tp_out_irreps_with_instructions(
+            self.node_feats_irreps,
+            self.edge_attrs_irreps,
+            self.target_irreps,
+        )
+        self.conv_tp = o3.TensorProduct(
+            self.node_feats_irreps,
+            self.edge_attrs_irreps,
+            irreps_mid,
+            instructions=instructions,
+            shared_weights=False,
+            internal_weights=False,
+        )
+
+        # Convolution weights
+        input_dim = self.edge_feats_irreps.num_irreps
+        self.conv_tp_weights = nn.FullyConnectedNet(
+            [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
+            torch.nn.functional.silu,
+        )
+
+        # Linear
+        irreps_mid = irreps_mid.simplify()
+        self.irreps_out = self.target_irreps
+        self.linear = o3.Linear(
+            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        )
+
+        # Selector TensorProduct
+        self.skip_tp = o3.FullyConnectedTensorProduct(
+            self.irreps_out, self.node_attrs_irreps, self.irreps_out
+        )
+        self.reshape = reshape_irreps(self.irreps_out)
+
+        # Density normalization
+        self.density_fn = nn.FullyConnectedNet(
+            [input_dim]
+            + [
+                1,
+            ],
+            torch.nn.functional.silu,
+        )
+        # Reshape
+        self.reshape = reshape_irreps(self.irreps_out)
+
+    def forward(
+        self,
+        node_attrs: torch.Tensor,
+        node_feats: torch.Tensor,
+        edge_attrs: torch.Tensor,
+        edge_feats: torch.Tensor,
+        edge_index: torch.Tensor,
+    ) -> Tuple[torch.Tensor, None]:
+        sender = edge_index[0]
+        receiver = edge_index[1]
+        num_nodes = node_feats.shape[0]
+        node_feats = self.linear_up(node_feats)
+        tp_weights = self.conv_tp_weights(edge_feats)
+        edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
+        mji = self.conv_tp(
+            node_feats[sender], edge_attrs, tp_weights
+        )  # [n_edges, irreps]
+        density = scatter_sum(
+            src=edge_density, index=receiver, dim=0, dim_size=num_nodes
+        )  # [n_nodes, 1]
+        message = scatter_sum(
+            src=mji, index=receiver, dim=0, dim_size=num_nodes
+        )  # [n_nodes, irreps]
+        message = self.linear(message) / (density + 1)
+        message = self.skip_tp(message, node_attrs)
+        return (
+            self.reshape(message),
+            None,
+        )  # [n_nodes, channels, (lmax + 1)**2]
+
+
+@compile_mode("script")
+class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
+    def _setup(self) -> None:
+        # First linear
+        self.linear_up = o3.Linear(
+            self.node_feats_irreps,
+            self.node_feats_irreps,
+            internal_weights=True,
+            shared_weights=True,
+        )
+        # TensorProduct
+        irreps_mid, instructions = tp_out_irreps_with_instructions(
+            self.node_feats_irreps,
+            self.edge_attrs_irreps,
+            self.target_irreps,
+        )
+        self.conv_tp = o3.TensorProduct(
+            self.node_feats_irreps,
+            self.edge_attrs_irreps,
+            irreps_mid,
+            instructions=instructions,
+            shared_weights=False,
+            internal_weights=False,
+        )
+
+        # Convolution weights
+        input_dim = self.edge_feats_irreps.num_irreps
+        self.conv_tp_weights = nn.FullyConnectedNet(
+            [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
+            torch.nn.functional.silu,  # gate
+        )
+
+        # Linear
+        irreps_mid = irreps_mid.simplify()
+        self.irreps_out = self.target_irreps
+        self.linear = o3.Linear(
+            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
+        )
+
+        # Selector TensorProduct
+        self.skip_tp = o3.FullyConnectedTensorProduct(
+            self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
+        )
+        self.reshape = reshape_irreps(self.irreps_out)
+
+        # Density normalization
+        self.density_fn = nn.FullyConnectedNet(
+            [input_dim]
+            + [
+                1,
+            ],
+            torch.nn.functional.silu,
+        )
+
+        # Reshape
+        self.reshape = reshape_irreps(self.irreps_out)
+
+    def forward(
+        self,
+        node_attrs: torch.Tensor,
+        node_feats: torch.Tensor,
+        edge_attrs: torch.Tensor,
+        edge_feats: torch.Tensor,
+        edge_index: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        sender = edge_index[0]
+        receiver = edge_index[1]
+        num_nodes = node_feats.shape[0]
+        sc = self.skip_tp(node_feats, node_attrs)
+        node_feats = self.linear_up(node_feats)
+        tp_weights = self.conv_tp_weights(edge_feats)
+        edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
+        mji = self.conv_tp(
+            node_feats[sender], edge_attrs, tp_weights
+        )  # [n_edges, irreps]
+        density = scatter_sum(
+            src=edge_density, index=receiver, dim=0, dim_size=num_nodes
+        )  # [n_nodes, 1]
+        message = scatter_sum(
+            src=mji, index=receiver, dim=0, dim_size=num_nodes
+        )  # [n_nodes, irreps]
+        message = self.linear(message) / (density + 1)
+        return (
+            self.reshape(message),
+            sc,
+        )  # [n_nodes, channels, (lmax + 1)**2]
+
+
 @compile_mode("script")
 class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
     def _setup(self) -> None:
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index 11a6d2f3..70428c8d 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -153,6 +153,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
             "RealAgnosticResidualInteractionBlock",
             "RealAgnosticAttResidualInteractionBlock",
             "RealAgnosticInteractionBlock",
+            "RealAgnosticDensityInteractionBlock",
+            "RealAgnosticDensityResidualInteractionBlock",
         ],
     )
     parser.add_argument(
@@ -163,6 +165,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         choices=[
             "RealAgnosticResidualInteractionBlock",
             "RealAgnosticInteractionBlock",
+            "RealAgnosticDensityInteractionBlock",
+            "RealAgnosticDensityResidualInteractionBlock",
         ],
     )
     parser.add_argument(
diff --git a/mace/tools/checkpoint.py b/mace/tools/checkpoint.py
index 8a62f1f2..81161ccc 100644
--- a/mace/tools/checkpoint.py
+++ b/mace/tools/checkpoint.py
@@ -64,7 +64,7 @@ def __init__(
         self._filename_extension = "pt"
 
     def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str:
-        if swa_start is not None and epochs > swa_start:
+        if swa_start is not None and epochs >= swa_start:
             return (
                 self.tag
                 + self._epochs_string
diff --git a/mace/tools/compile.py b/mace/tools/compile.py
index 425e4c02..03282067 100644
--- a/mace/tools/compile.py
+++ b/mace/tools/compile.py
@@ -36,7 +36,7 @@ def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory:
     """
     if allow_autograd:
         dynamo.allow_in_graph(autograd.grad)
-    elif dynamo.allowed_functions.is_allowed(autograd.grad):
+    else:
         dynamo.disallow_in_graph(autograd.grad)
 
     @wraps(func)
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index ac9d09fb..d20e942b 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -327,6 +327,9 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict:
             logging.info(f"Loading atomic energies from {E0s}")
             with open(E0s, "r", encoding="utf-8") as f:
                 atomic_energies_dict = json.load(f)
+            atomic_energies_dict = {
+                int(key): value for key, value in atomic_energies_dict.items()
+            }
         else:
             try:
                 atomic_energies_eval = ast.literal_eval(E0s)
@@ -624,8 +627,10 @@ def custom_key(key):
 
 
 def dict_to_array(input_data, heads):
+    if all(isinstance(value, np.ndarray) for value in input_data.values()):
+        return np.array([input_data[head] for head in heads])
     if not all(isinstance(value, dict) for value in input_data.values()):
-        return np.array(list(input_data.values()))
+        return np.array([[input_data[head]] for head in heads])
     unique_keys = set()
     for inner_dict in input_data.values():
         unique_keys.update(inner_dict.keys())
@@ -637,7 +642,7 @@ def dict_to_array(input_data, heads):
             key_index = sorted_keys.index(int(key))
             head_index = heads.index(head_name)
             result_array[head_index][key_index] = value
-    return np.squeeze(result_array)
+    return result_array
 
 
 class LRScheduler:
diff --git a/mace/tools/train.py b/mace/tools/train.py
index 8e293bee..3c6b8325 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -60,7 +60,7 @@ def valid_err_log(
         error_e = eval_metrics["rmse_e_per_atom"] * 1e3
         error_f = eval_metrics["rmse_f"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A"
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A"
         )
     elif (
         log_errors == "PerAtomRMSEstressvirials"
@@ -70,7 +70,7 @@ def valid_err_log(
         error_f = eval_metrics["rmse_f"] * 1e3
         error_stress = eval_metrics["rmse_stress"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress={error_stress:8.1f} meV / A^3",
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3",
         )
     elif (
         log_errors == "PerAtomRMSEstressvirials"
@@ -80,7 +80,7 @@ def valid_err_log(
         error_f = eval_metrics["rmse_f"] * 1e3
         error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV",
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV",
         )
     elif (
         log_errors == "PerAtomMAEstressvirials"
@@ -90,7 +90,7 @@ def valid_err_log(
         error_f = eval_metrics["mae_f"] * 1e3
         error_stress = eval_metrics["mae_stress"] * 1e3
         logging.info(
-            f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3"
+            f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3"
         )
     elif (
         log_errors == "PerAtomMAEstressvirials"
@@ -100,37 +100,37 @@ def valid_err_log(
         error_f = eval_metrics["mae_f"] * 1e3
         error_virials = eval_metrics["mae_virials"] * 1e3
         logging.info(
-            f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV"
+            f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV"
         )
     elif log_errors == "TotalRMSE":
         error_e = eval_metrics["rmse_e"] * 1e3
         error_f = eval_metrics["rmse_f"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A",
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A",
         )
     elif log_errors == "PerAtomMAE":
         error_e = eval_metrics["mae_e_per_atom"] * 1e3
         error_f = eval_metrics["mae_f"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A",
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
         )
     elif log_errors == "TotalMAE":
         error_e = eval_metrics["mae_e"] * 1e3
         error_f = eval_metrics["mae_f"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A",
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
         )
     elif log_errors == "DipoleRMSE":
         error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye",
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye",
         )
     elif log_errors == "EnergyDipoleRMSE":
         error_e = eval_metrics["rmse_e_per_atom"] * 1e3
         error_f = eval_metrics["rmse_f"] * 1e3
         error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
         logging.info(
-            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye",
+            f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye",
         )
 
 
diff --git a/setup.cfg b/setup.cfg
index 13d55161..139f914e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -42,6 +42,7 @@ console_scripts =
     mace_run_train = mace.cli.run_train:main
    mace_prepare_data = mace.cli.preprocess_data:main
    mace_finetuning = mace.cli.fine_tuning_select:main
+    mace_convert_dev = mace.cli.convert_dev:main
 
 [options.extras_require]
 wandb = wandb
@@ -52,5 +53,6 @@ dev =
     mypy
     pre-commit
     pytest
+    pytest-benchmark
     pylint
-schedulefree = schedulefree
\ No newline at end of file
+schedulefree = schedulefree
diff --git a/tests/test_compile.py b/tests/test_compile.py
index 01106bef..d7d585e8 100644
--- a/tests/test_compile.py
+++ b/tests/test_compile.py
@@ -42,8 +42,10 @@ def create_mace(device: str, seed: int = 1702):
         "atomic_numbers": table.zs,
         "correlation": 3,
         "radial_type": "bessel",
+        "atomic_inter_scale": 1.0,
+        "atomic_inter_shift": 0.0,
     }
-    model = modules.MACE(**model_config)
+    model = modules.ScaleShiftMACE(**model_config)
     return model.to(device)
 
 
@@ -122,11 +124,14 @@ def test_eager_benchmark(benchmark, default_dtype):  # pylint: disable=W0621
 @pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"])
 @pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"])
 def test_compile_benchmark(benchmark, compile_mode, enable_amp):
+    if enable_amp:
+        pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default")
+
     with tools.torch_tools.default_dtype(torch.float32):
         batch = create_batch("cuda")
         torch.compiler.reset()
         model = mace_compile.prepare(create_mace)("cuda")
-        model = torch.compile(model, mode=compile_mode, fullgraph=True)
+        model = torch.compile(model, mode=compile_mode)
         model = time_func(model)
 
         with torch.autocast("cuda", enabled=enable_amp):
diff --git a/tests/test_foundations.py b/tests/test_foundations.py
index fa35f8b9..03ea85c3 100644
--- a/tests/test_foundations.py
+++ b/tests/test_foundations.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import numpy as np
 import pytest
 import torch
@@ -13,6 +15,14 @@
 from mace.tools.scripts_utils import extract_config_mace_model
 from mace.tools.utils import AtomicNumberTable
 
+MODEL_PATH = (
+    Path(__file__).parent.parent
+    / "mace"
+    / "calculators"
+    / "foundations_models"
+    / "2023-12-03-mace-mp.model"
+)
+
 torch.set_default_dtype(torch.float64)
 config = data.Configuration(
     atomic_numbers=molecule("H2COH").numbers,
@@ -172,9 +182,11 @@ def test_multi_reference():
         mace_mp(model="small", device="cpu", default_dtype="float64").models[0],
         mace_mp(model="medium", device="cpu", default_dtype="float64").models[0],
         mace_mp(model="large", device="cpu", default_dtype="float64").models[0],
+        mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0],
         mace_off(model="small", device="cpu", default_dtype="float64").models[0],
         mace_off(model="medium", device="cpu", default_dtype="float64").models[0],
mace_off(model="large", device="cpu", default_dtype="float64").models[0], + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], ], ) def test_extract_config(model): diff --git a/tests/test_run_train.py b/tests/test_run_train.py index ca003317..ca196c47 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -1,3 +1,4 @@ +import json import os import subprocess import sys @@ -600,6 +601,123 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): assert np.allclose(Es, ref_Es, atol=1e-1) +def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + if i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + + heads = { + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 
0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + def test_run_train_multihead_replay_custum_finetuning( tmp_path, fitting_configs, pretraining_configs ):