From 38377f8911264f95c7ebbb4bf8f0442c36ad793f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 19 Sep 2024 08:58:41 +0100 Subject: [PATCH 01/10] test density normalization --- mace/modules/__init__.py | 4 + mace/modules/blocks.py | 192 +++++++++++++++++++++++++++++++++++++++ mace/tools/arg_parser.py | 4 + 3 files changed, 200 insertions(+) diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 9278130f..69e102b5 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -18,6 +18,8 @@ RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, ScaleShiftBlock, ) from .loss import ( @@ -56,6 +58,8 @@ "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, } scaling_classes: Dict[str, Callable] = { diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 34539b0b..2bd20f74 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -664,6 +664,198 @@ def forward( ) # [n_nodes, channels, (lmax + 1)**2] +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + num_scalar_node_features = self.node_feats_irreps[0].mul + self.node_scalar_linear = torch.nn.Linear( + num_scalar_node_features, self.conv_tp.weight_numel + ) + + self.reshape = reshape_irreps(self.irreps_out) + self.density_fn = nn.FullyConnectedNet( + [self.conv_tp.weight_numel] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + node_feats_scalar = self.node_scalar_linear( + node_feats[:, self.node_feats_irreps.slices()[0]] + ) + edge_density = torch.tanh( + self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 + ) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + num_scalar_node_features = self.node_feats_irreps[0].mul + self.node_scalar_linear = torch.nn.Linear( + num_scalar_node_features, self.conv_tp.weight_numel + ) + + self.reshape = reshape_irreps(self.irreps_out) + self.density_fn = nn.FullyConnectedNet( + [self.conv_tp.weight_numel] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + node_feats_scalar = self.node_scalar_linear( + node_feats[:, self.node_feats_irreps.slices()[0]] + ) + edge_density = torch.tanh( + self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 + ) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 046f04d6..8fa8c0ac 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -153,6 +153,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "RealAgnosticResidualInteractionBlock", "RealAgnosticAttResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( @@ -163,6 +165,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: choices=[ "RealAgnosticResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( From 294f90cfced3e9a518c1584ba867cc0ceb018cb3 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 7 Oct 2024 10:23:56 +0100 Subject: [PATCH 02/10] simplify the density normalization --- mace/modules/blocks.py | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 2bd20f74..0db3b02e 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -710,19 +710,15 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Density normalization - num_scalar_node_features = self.node_feats_irreps[0].mul - self.node_scalar_linear = torch.nn.Linear( - num_scalar_node_features, self.conv_tp.weight_numel - ) - - self.reshape = reshape_irreps(self.irreps_out) self.density_fn = nn.FullyConnectedNet( - [self.conv_tp.weight_numel] + [input_dim] + [ 1, ], torch.nn.functional.silu, ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) def forward( self, @@ -737,12 +733,7 @@ def forward( num_nodes = node_feats.shape[0] node_feats = self.linear_up(node_feats) tp_weights = self.conv_tp_weights(edge_feats) - node_feats_scalar = self.node_scalar_linear( - node_feats[:, self.node_feats_irreps.slices()[0]] - ) - edge_density = torch.tanh( - self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 - ) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) mji = self.conv_tp( node_feats[sender], edge_attrs, tp_weights ) # [n_edges, irreps] @@ -806,20 +797,17 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Density normalization - num_scalar_node_features = self.node_feats_irreps[0].mul - self.node_scalar_linear = torch.nn.Linear( - num_scalar_node_features, self.conv_tp.weight_numel - ) - - self.reshape = reshape_irreps(self.irreps_out) self.density_fn = nn.FullyConnectedNet( - [self.conv_tp.weight_numel] + [input_dim] + [ 1, ], torch.nn.functional.silu, ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + def forward( self, node_attrs: torch.Tensor, @@ -834,12 +822,7 @@ def forward( sc = self.skip_tp(node_feats, node_attrs) node_feats = self.linear_up(node_feats) tp_weights = self.conv_tp_weights(edge_feats) - node_feats_scalar = self.node_scalar_linear( - node_feats[:, self.node_feats_irreps.slices()[0]] - ) - edge_density = torch.tanh( - self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 - ) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) mji = self.conv_tp( node_feats[sender], edge_attrs, tp_weights ) # [n_edges, irreps] From f6124f24c7a46c2781b2e6a2af16b69735fb9dc0 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 30 Oct 2024 09:33:08 -0400 Subject: [PATCH 03/10] Add mace_convert_dev cli tool to convert between devices --- mace/cli/convert_dev.py | 21 +++++++++++++++++++++ setup.cfg | 3 ++- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 mace/cli/convert_dev.py diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_dev.py new file mode 100644 index 00000000..cc2ac1a6 --- /dev/null +++ b/mace/cli/convert_dev.py @@ -0,0 +1,21 @@ +from argparse import ArgumentParser +import torch + +def main(): + parser = ArgumentParser() + parser.add_argument("--target_device", "-t", + help="device to convert to, usually 'cpu' or 'cuda'", default="cpu") + parser.add_argument("--output_file", "-o", + help="name for output model, defaults to model_file.target_device") + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file) + model.to(args.target_device) + torch.save(model, args.output_file) + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index 6751b12d..139f914e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ console_scripts = mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main mace_finetuning = mace.cli.fine_tuning_select:main + mace_convert_dev = mace.cli.convert_dev:main [options.extras_require] wandb = wandb @@ -54,4 +55,4 @@ dev = pytest pytest-benchmark pylint -schedulefree = schedulefree \ No newline at end of file +schedulefree = schedulefree From 165993c4db404644e574575fc69490f56fad473b Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:53:23 +0000 Subject: [PATCH 04/10] fixing import order --- mace/modules/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 69e102b5..e48e0b23 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -15,11 +15,11 @@ NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, - RealAgnosticDensityResidualInteractionBlock, - RealAgnosticDensityInteractionBlock, ScaleShiftBlock, ) from .loss import ( From 6a7901594f54fc3f5b2bea5088ef844b57368964 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:21:37 +0000 Subject: [PATCH 05/10] fix formatting --- mace/cli/convert_dev.py | 18 ++++++++++++++---- mace/cli/eval_configs.py | 4 ++-- tests/test_run_train.py | 8 ++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_dev.py index cc2ac1a6..9dd8c61d 100644 --- a/mace/cli/convert_dev.py +++ b/mace/cli/convert_dev.py @@ -1,12 +1,21 @@ from argparse import ArgumentParser + import torch + def main(): parser = ArgumentParser() - parser.add_argument("--target_device", "-t", - help="device to convert to, usually 'cpu' or 'cuda'", default="cpu") - parser.add_argument("--output_file", "-o", - help="name for output model, defaults to model_file.target_device") + parser.add_argument( + "--target_device", + "-t", + help="device to convert to, usually 'cpu' or 'cuda'", + default="cpu", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) parser.add_argument("model_file", help="input model file path") args = parser.parse_args() @@ -17,5 +26,6 @@ def main(): model.to(args.target_device) torch.save(model, args.output_file) + if __name__ == "__main__": main() diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index f44f7515..86470f8d 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -58,7 +58,7 @@ def parse_args() -> argparse.Namespace: help="Model head used for evaluation", type=str, required=False, - default=None + default=None, ) return parser.parse_args() @@ -94,7 +94,7 @@ def run(args: argparse.Namespace) -> None: heads = model.heads except AttributeError: heads = None - + data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 45f11c67..ca196c47 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -1,3 +1,4 @@ +import json import os import subprocess import sys @@ -7,7 +8,6 @@ import numpy as np import pytest from ase.atoms import Atoms -import json from mace.calculators.mace import MACECalculator @@ -608,7 +608,7 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): if i in (0, 1): continue # skip isolated atoms, as energies specified by json files below - elif i % 2 == 0: + if i % 2 == 0: c.info["head"] = "DFT" fitting_configs_dft.append(c) else: @@ -619,9 +619,9 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): # write E0s to json files E0s = {1: 0.0, 8: 0.0} - with open(tmp_path / "fit_multihead_dft.json", "w") as f: + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: json.dump(E0s, f) - with open(tmp_path / "fit_multihead_mp2.json", "w") as f: + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: json.dump(E0s, f) heads = { From 787cda974ba799669e72a20d760074b9d07ad9d5 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:52:57 +0000 Subject: [PATCH 06/10] fix multiple theory single atom case --- mace/cli/run_train.py | 1 - mace/tools/scripts_utils.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 8cab392e..14509e4d 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -653,7 +653,6 @@ def run(args: argparse.Namespace) -> None: folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) for test_name, test_set in test_sets.items(): - print(test_name) test_sampler = None if args.distributed: test_sampler = torch.utils.data.distributed.DistributedSampler( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index ec3d4637..eb70b4d4 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -628,7 +628,7 @@ def custom_key(key): def dict_to_array(input_data, heads): if not all(isinstance(value, dict) for value in input_data.values()): - return np.array(list(input_data.values())) + return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): unique_keys.update(inner_dict.keys()) @@ -640,7 +640,7 @@ def dict_to_array(input_data, heads): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) result_array[head_index][key_index] = value - return np.squeeze(result_array) + return result_array class LRScheduler: From 74ecfdae7778a56a55f932700b4e5ce17665c0d5 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 15:57:08 +0000 Subject: [PATCH 07/10] fix ndarray behavior dict_to_array --- mace/tools/scripts_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index eb70b4d4..3e2e1ed7 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -627,7 +627,9 @@ def custom_key(key): def dict_to_array(input_data, heads): - if not all(isinstance(value, dict) for value in input_data.values()): + if all(isinstance(value, np.ndarray) for value in input_data.values()): + return np.array([input_data[head] for head in heads]) + elif not all(isinstance(value, dict) for value in input_data.values()): return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): From 35c7de694b0fe74a89389d5b7b0b78e5fe274024 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Fri, 1 Nov 2024 22:23:56 +0000 Subject: [PATCH 08/10] Adding formatter_class to argparser --- mace/cli/active_learning_md.py | 4 +++- mace/cli/create_lammps_model.py | 4 +++- mace/cli/eval_configs.py | 4 +++- mace/cli/fine_tuning_select.py | 4 +++- mace/cli/plot_train.py | 5 ++++- mace/tools/arg_parser.py | 9 +++++++-- 6 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mace/cli/active_learning_md.py b/mace/cli/active_learning_md.py index a26be698..9cf4f4a8 100644 --- a/mace/cli/active_learning_md.py +++ b/mace/cli/active_learning_md.py @@ -14,7 +14,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--config", help="path to XYZ configurations", required=True) parser.add_argument( "--config_index", help="index of configuration", type=int, default=-1 diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 1917ab8e..507a2cd0 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -7,7 +7,9 @@ def parse_args(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "model_path", type=str, diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index b5700bc4..7ea94012 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -16,7 +16,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--configs", help="path to XYZ configurations", required=True) parser.add_argument("--model", help="path to model", required=True) parser.add_argument("--output", help="output path", required=True) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 2fa5f644..94baf0dd 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -20,7 +20,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--configs_pt", help="path to XYZ configurations for the pretraining", diff --git a/mace/cli/plot_train.py b/mace/cli/plot_train.py index c249d76a..a1c424df 100644 --- a/mace/cli/plot_train.py +++ b/mace/cli/plot_train.py @@ -60,7 +60,10 @@ def parse_training_results(path: str) -> List[dict]: def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Plot mace training statistics") + parser = argparse.ArgumentParser( + description="Plot mace training statistics", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--path", help="path to results file or directory", required=True ) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 11a6d2f3..fddb3b72 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -15,6 +15,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser = configargparse.ArgumentParser( config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add( "--config", @@ -23,7 +24,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: help="config file to agregate options", ) except ImportError: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) # Name and seed parser.add_argument("--name", help="experiment name", required=True) @@ -700,7 +703,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: def build_preprocess_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--train_file", help="Training set h5 file", From 32d2f97d9ce1805555700fcac0d1c63811b2ca1d Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:45:47 +0000 Subject: [PATCH 09/10] fix loading twice foundation E0s --- mace/cli/run_train.py | 18 +++++++++++++++--- mace/tools/scripts_utils.py | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 14509e4d..9b484d7f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -353,12 +353,18 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict[head_config.head_name] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -372,8 +378,14 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict["pt_head"] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 3e2e1ed7..d20e942b 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -629,7 +629,7 @@ def custom_key(key): def dict_to_array(input_data, heads): if all(isinstance(value, np.ndarray) for value in input_data.values()): return np.array([input_data[head] for head in heads]) - elif not all(isinstance(value, dict) for value in input_data.values()): + if not all(isinstance(value, dict) for value in input_data.values()): return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): From 29c99ae084063230cbb032d43c3dda8b01447ddd Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:02:42 +0000 Subject: [PATCH 10/10] fix formatting --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 9b484d7f..7f1a5e74 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -364,7 +364,7 @@ def run(args: argparse.Namespace) -> None: z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: