From 38377f8911264f95c7ebbb4bf8f0442c36ad793f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 19 Sep 2024 08:58:41 +0100 Subject: [PATCH 01/22] test density normalization --- mace/modules/__init__.py | 4 + mace/modules/blocks.py | 192 +++++++++++++++++++++++++++++++++++++++ mace/tools/arg_parser.py | 4 + 3 files changed, 200 insertions(+) diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 9278130f..69e102b5 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -18,6 +18,8 @@ RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, ScaleShiftBlock, ) from .loss import ( @@ -56,6 +58,8 @@ "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, } scaling_classes: Dict[str, Callable] = { diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 34539b0b..2bd20f74 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -664,6 +664,198 @@ def forward( ) # [n_nodes, channels, (lmax + 1)**2] +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + num_scalar_node_features = self.node_feats_irreps[0].mul + self.node_scalar_linear = torch.nn.Linear( + num_scalar_node_features, self.conv_tp.weight_numel + ) + + self.reshape = reshape_irreps(self.irreps_out) + self.density_fn = nn.FullyConnectedNet( + [self.conv_tp.weight_numel] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + node_feats_scalar = self.node_scalar_linear( + node_feats[:, self.node_feats_irreps.slices()[0]] + ) + edge_density = torch.tanh( + self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 + ) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + num_scalar_node_features = self.node_feats_irreps[0].mul + self.node_scalar_linear = torch.nn.Linear( + num_scalar_node_features, self.conv_tp.weight_numel + ) + + self.reshape = reshape_irreps(self.irreps_out) + self.density_fn = nn.FullyConnectedNet( + [self.conv_tp.weight_numel] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + node_feats_scalar = self.node_scalar_linear( + node_feats[:, self.node_feats_irreps.slices()[0]] + ) + edge_density = torch.tanh( + self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 + ) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 046f04d6..8fa8c0ac 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -153,6 +153,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "RealAgnosticResidualInteractionBlock", "RealAgnosticAttResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( @@ -163,6 +165,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: choices=[ "RealAgnosticResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( From 294f90cfced3e9a518c1584ba867cc0ceb018cb3 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 7 Oct 2024 10:23:56 +0100 Subject: [PATCH 02/22] simplify the density normalization --- mace/modules/blocks.py | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 2bd20f74..0db3b02e 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -710,19 +710,15 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Density normalization - num_scalar_node_features = self.node_feats_irreps[0].mul - self.node_scalar_linear = torch.nn.Linear( - num_scalar_node_features, self.conv_tp.weight_numel - ) - - self.reshape = reshape_irreps(self.irreps_out) self.density_fn = nn.FullyConnectedNet( - [self.conv_tp.weight_numel] + [input_dim] + [ 1, ], torch.nn.functional.silu, ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) def forward( self, @@ -737,12 +733,7 @@ def forward( num_nodes = node_feats.shape[0] node_feats = self.linear_up(node_feats) tp_weights = self.conv_tp_weights(edge_feats) - node_feats_scalar = self.node_scalar_linear( - node_feats[:, self.node_feats_irreps.slices()[0]] - ) - edge_density = torch.tanh( - self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 - ) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) mji = self.conv_tp( node_feats[sender], edge_attrs, tp_weights ) # [n_edges, irreps] @@ -806,20 +797,17 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Density normalization - num_scalar_node_features = self.node_feats_irreps[0].mul - self.node_scalar_linear = torch.nn.Linear( - num_scalar_node_features, self.conv_tp.weight_numel - ) - - self.reshape = reshape_irreps(self.irreps_out) self.density_fn = nn.FullyConnectedNet( - [self.conv_tp.weight_numel] + [input_dim] + [ 1, ], torch.nn.functional.silu, ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + def forward( self, node_attrs: torch.Tensor, @@ -834,12 +822,7 @@ def forward( sc = self.skip_tp(node_feats, node_attrs) node_feats = self.linear_up(node_feats) tp_weights = self.conv_tp_weights(edge_feats) - node_feats_scalar = self.node_scalar_linear( - node_feats[:, self.node_feats_irreps.slices()[0]] - ) - edge_density = torch.tanh( - self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 - ) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) mji = self.conv_tp( node_feats[sender], edge_attrs, tp_weights ) # [n_edges, irreps] From 1da124aefe4259b5fd9960ec394dc0e765456d16 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 10 Oct 2024 06:10:14 -0600 Subject: [PATCH 03/22] Fix compile_mode in MACECalculator --- mace/calculators/mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 9d307eda..dcd2b8e5 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -159,7 +159,7 @@ def __init__( mode=compile_mode, fullgraph=fullgraph, ) - for model in models + for model in self.models ] self.use_compile = True else: From 0f191264587988bc3ff0995c1d9ebd45b1978afb Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 10 Oct 2024 14:31:23 -0600 Subject: [PATCH 04/22] Fixing compile test cases --- mace/tools/compile.py | 2 +- setup.cfg | 1 + tests/test_compile.py | 9 +++++++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mace/tools/compile.py b/mace/tools/compile.py index 425e4c02..03282067 100644 --- a/mace/tools/compile.py +++ b/mace/tools/compile.py @@ -36,7 +36,7 @@ def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: """ if allow_autograd: dynamo.allow_in_graph(autograd.grad) - elif dynamo.allowed_functions.is_allowed(autograd.grad): + else: dynamo.disallow_in_graph(autograd.grad) @wraps(func) diff --git a/setup.cfg b/setup.cfg index 13d55161..6751b12d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,5 +52,6 @@ dev = mypy pre-commit pytest + pytest-benchmark pylint schedulefree = schedulefree \ No newline at end of file diff --git a/tests/test_compile.py b/tests/test_compile.py index 01106bef..d7d585e8 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -42,8 +42,10 @@ def create_mace(device: str, seed: int = 1702): "atomic_numbers": table.zs, "correlation": 3, "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, } - model = modules.MACE(**model_config) + model = modules.ScaleShiftMACE(**model_config) return model.to(device) @@ -122,11 +124,14 @@ def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621 @pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"]) @pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"]) def test_compile_benchmark(benchmark, compile_mode, enable_amp): + if enable_amp: + pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default") + with tools.torch_tools.default_dtype(torch.float32): batch = create_batch("cuda") torch.compiler.reset() model = mace_compile.prepare(create_mace)("cuda") - model = torch.compile(model, mode=compile_mode, fullgraph=True) + model = torch.compile(model, mode=compile_mode) model = time_func(model) with torch.autocast("cuda", enabled=enable_amp): From 7c9c9281fa678d1c59d7407fa900c11848d88dd2 Mon Sep 17 00:00:00 2001 From: Hubert Beck Date: Fri, 11 Oct 2024 11:52:06 +0100 Subject: [PATCH 05/22] fix reading in heads from model --- mace/cli/eval_configs.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index b5700bc4..79d886e0 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -80,10 +80,15 @@ def run(args: argparse.Namespace) -> None: z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + try: + heads = model.heads + except AttributeError: + heads = None + data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=z_table, cutoff=float(model.r_max) + config, z_table=z_table, cutoff=float(model.r_max), heads=heads ) for config in configs ], From 1b7e369e6e34ca8f5e0ecddff052698b188285a9 Mon Sep 17 00:00:00 2001 From: Hubert Beck Date: Fri, 11 Oct 2024 14:19:07 +0100 Subject: [PATCH 06/22] Add convenience argument --- mace/cli/eval_configs.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index 79d886e0..f44f7515 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -53,6 +53,13 @@ def parse_args() -> argparse.Namespace: type=str, default="MACE_", ) + parser.add_argument( + "--head", + help="Model head used for evaluation", + type=str, + required=False, + default=None + ) return parser.parse_args() @@ -76,6 +83,9 @@ def run(args: argparse.Namespace) -> None: # Load data and prepare input atoms_list = ase.io.read(args.configs, index=":") + if args.head is not None: + for atoms in atoms_list: + atoms.info["head"] = args.head configs = [data.config_from_atoms(atoms) for atoms in atoms_list] z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) From 377523f33f9e080ea364c6020e4ee58c9b693c13 Mon Sep 17 00:00:00 2001 From: Hubert Beck <71390574+beckobert@users.noreply.github.com> Date: Tue, 22 Oct 2024 10:43:50 +0100 Subject: [PATCH 07/22] Fix echos => swa_start --- mace/tools/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/tools/checkpoint.py b/mace/tools/checkpoint.py index 8a62f1f2..ed673cdf 100644 --- a/mace/tools/checkpoint.py +++ b/mace/tools/checkpoint.py @@ -64,7 +64,7 @@ def __init__( self._filename_extension = "pt" def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: - if swa_start is not None and epochs > swa_start: + if swa_start is not None and epochs => swa_start: return ( self.tag + self._epochs_string From 802669ab51068621c24b756a168969ac3dfb7a5a Mon Sep 17 00:00:00 2001 From: Hubert Beck <71390574+beckobert@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:24:34 +0100 Subject: [PATCH 08/22] typo How did I manage a typo when adding a single character? --- mace/tools/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/tools/checkpoint.py b/mace/tools/checkpoint.py index ed673cdf..81161ccc 100644 --- a/mace/tools/checkpoint.py +++ b/mace/tools/checkpoint.py @@ -64,7 +64,7 @@ def __init__( self._filename_extension = "pt" def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: - if swa_start is not None and epochs => swa_start: + if swa_start is not None and epochs >= swa_start: return ( self.tag + self._epochs_string From 2386320c2b6a06f0567f6e29cec06fb87938f19b Mon Sep 17 00:00:00 2001 From: Thomas Warford Date: Tue, 22 Oct 2024 18:01:23 +0100 Subject: [PATCH 09/22] Turn json keys representing elements into ints --- mace/tools/scripts_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index ac9d09fb..ec3d4637 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -327,6 +327,9 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: logging.info(f"Loading atomic energies from {E0s}") with open(E0s, "r", encoding="utf-8") as f: atomic_energies_dict = json.load(f) + atomic_energies_dict = { + int(key): value for key, value in atomic_energies_dict.items() + } else: try: atomic_energies_eval = ast.literal_eval(E0s) From ca759e14fe4e330a815efc950fd10a623e5bdae5 Mon Sep 17 00:00:00 2001 From: Thomas Warford Date: Wed, 23 Oct 2024 10:20:36 +0100 Subject: [PATCH 10/22] Test for reading E0s from json for multihead --- tests/test_run_train.py | 112 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index ca003317..ba6e2c7b 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -7,6 +7,7 @@ import numpy as np import pytest from ase.atoms import Atoms +import json from mace.calculators.mace import MACECalculator @@ -600,6 +601,117 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): assert np.allclose(Es, ref_Es, atol=1e-1) +def test_run_train_foundation_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w") as f: + json.dump(E0s, f) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", "E0s": f"{str(tmp_path)}/fit_multihead_dft.json"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + def test_run_train_multihead_replay_custum_finetuning( tmp_path, fitting_configs, pretraining_configs ): From 3cb7e962fc69c532cab6ae4d4e4e815bee2a7f81 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:53:03 +0100 Subject: [PATCH 11/22] Test passing model paths --- tests/test_foundations.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_foundations.py b/tests/test_foundations.py index fa35f8b9..03ea85c3 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytest import torch @@ -13,6 +15,14 @@ from mace.tools.scripts_utils import extract_config_mace_model from mace.tools.utils import AtomicNumberTable +MODEL_PATH = ( + Path(__file__).parent.parent + / "mace" + / "calculators" + / "foundations_models" + / "2023-12-03-mace-mp.model" +) + torch.set_default_dtype(torch.float64) config = data.Configuration( atomic_numbers=molecule("H2COH").numbers, @@ -172,9 +182,11 @@ def test_multi_reference(): mace_mp(model="small", device="cpu", default_dtype="float64").models[0], mace_mp(model="medium", device="cpu", default_dtype="float64").models[0], mace_mp(model="large", device="cpu", default_dtype="float64").models[0], + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], mace_off(model="small", device="cpu", default_dtype="float64").models[0], mace_off(model="medium", device="cpu", default_dtype="float64").models[0], mace_off(model="large", device="cpu", default_dtype="float64").models[0], + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], ], ) def test_extract_config(model): From 90aeeca7526698e861a21b042e951abb5234bbf0 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:54:20 +0100 Subject: [PATCH 12/22] Fix model paths for mace_mp --- mace/calculators/foundations_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 5c9a896f..782ec38e 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -101,8 +101,15 @@ def mace_mp( MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). """ try: - model_path = download_mace_mp_checkpoint(model) - print(f"Using Materials Project MACE for MACECalculator with {model_path}") + if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + model_path = download_mace_mp_checkpoint(model) + print(f"Using Materials Project MACE for MACECalculator with {model_path}") + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + model_path = model except Exception as exc: raise RuntimeError("Model download failed and no local model found") from exc From b0185fb4528695313f0502b19422b2fe7e51ba9a Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:01:17 +0100 Subject: [PATCH 13/22] Fix model paths for mace_off --- mace/calculators/foundations_models.py | 60 ++++++++++++++------------ 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 782ec38e..ed814f1a 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -180,36 +180,42 @@ def mace_off( MACECalculator: trained on the MACE-OFF23 dataset """ try: - urls = dict( - small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", - medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", - large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", - ) - checkpoint_url = ( - urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") - else model - ) - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] - cached_model_path = f"{cache_dir}/{checkpoint_url_name}" - if not os.path.isfile(cached_model_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - print(f"Downloading MACE model from {checkpoint_url!r}") - print( - "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." + if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + urls = dict( + small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", + medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", + large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", ) - print( - "ASL is based on the Gnu Public License, but does not permit commercial use" + checkpoint_url = ( + urls.get(model, urls["medium"]) + if model in (None, "small", "medium", "large") + else model ) - urllib.request.urlretrieve(checkpoint_url, cached_model_path) - print(f"Cached MACE model to {cached_model_path}") - model = cached_model_path - msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" - print(msg) + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + print(f"Downloading MACE model from {checkpoint_url!r}") + print( + "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." + ) + print( + "ASL is based on the Gnu Public License, but does not permit commercial use" + ) + urllib.request.urlretrieve(checkpoint_url, cached_model_path) + print(f"Cached MACE model to {cached_model_path}") + model = cached_model_path + msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" + print(msg) + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") except Exception as exc: - raise RuntimeError("Model download failed") from exc + raise RuntimeError("Model download failed and no local model found") from exc device = device or ("cuda" if torch.cuda.is_available() else "cpu") From 8c59792b5efa7a5d54b81de2e4673b46cc30d0f9 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Mon, 28 Oct 2024 12:04:16 +0000 Subject: [PATCH 14/22] Increasing loss output digits --- mace/tools/train.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 8e293bee..3c6b8325 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -60,7 +60,7 @@ def valid_err_log( error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -70,7 +70,7 @@ def valid_err_log( error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress={error_stress:8.1f} meV / A^3", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -80,7 +80,7 @@ def valid_err_log( error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV", ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -90,7 +90,7 @@ def valid_err_log( error_f = eval_metrics["mae_f"] * 1e3 error_stress = eval_metrics["mae_stress"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3" + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3" ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -100,37 +100,37 @@ def valid_err_log( error_f = eval_metrics["mae_f"] * 1e3 error_virials = eval_metrics["mae_virials"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV" + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", ) From c1bb3b2604ceb135cd69c5dd567891ee488112e0 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:26:35 +0000 Subject: [PATCH 15/22] change the test name for json run_train mh --- tests/test_run_train.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index ba6e2c7b..45f11c67 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -601,13 +601,13 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): assert np.allclose(Es, ref_Es, atol=1e-1) -def test_run_train_foundation_multihead(tmp_path, fitting_configs): +def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): fitting_configs_dft = [] fitting_configs_mp2 = [] for i, c in enumerate(fitting_configs): - + if i in (0, 1): - continue # skip isolated atoms, as energies specified by json files below + continue # skip isolated atoms, as energies specified by json files below elif i % 2 == 0: c.info["head"] = "DFT" fitting_configs_dft.append(c) @@ -625,8 +625,14 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): json.dump(E0s, f) heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", "E0s": f"{str(tmp_path)}/fit_multihead_dft.json"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json"}, + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, } yaml_str = "heads:\n" for key, value in heads.items(): From f6124f24c7a46c2781b2e6a2af16b69735fb9dc0 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 30 Oct 2024 09:33:08 -0400 Subject: [PATCH 16/22] Add mace_convert_dev cli tool to convert between devices --- mace/cli/convert_dev.py | 21 +++++++++++++++++++++ setup.cfg | 3 ++- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 mace/cli/convert_dev.py diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_dev.py new file mode 100644 index 00000000..cc2ac1a6 --- /dev/null +++ b/mace/cli/convert_dev.py @@ -0,0 +1,21 @@ +from argparse import ArgumentParser +import torch + +def main(): + parser = ArgumentParser() + parser.add_argument("--target_device", "-t", + help="device to convert to, usually 'cpu' or 'cuda'", default="cpu") + parser.add_argument("--output_file", "-o", + help="name for output model, defaults to model_file.target_device") + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file) + model.to(args.target_device) + torch.save(model, args.output_file) + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index 6751b12d..139f914e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ console_scripts = mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main mace_finetuning = mace.cli.fine_tuning_select:main + mace_convert_dev = mace.cli.convert_dev:main [options.extras_require] wandb = wandb @@ -54,4 +55,4 @@ dev = pytest pytest-benchmark pylint -schedulefree = schedulefree \ No newline at end of file +schedulefree = schedulefree From 165993c4db404644e574575fc69490f56fad473b Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:53:23 +0000 Subject: [PATCH 17/22] fixing import order --- mace/modules/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 69e102b5..e48e0b23 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -15,11 +15,11 @@ NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, - RealAgnosticDensityResidualInteractionBlock, - RealAgnosticDensityInteractionBlock, ScaleShiftBlock, ) from .loss import ( From 6a7901594f54fc3f5b2bea5088ef844b57368964 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:21:37 +0000 Subject: [PATCH 18/22] fix formatting --- mace/cli/convert_dev.py | 18 ++++++++++++++---- mace/cli/eval_configs.py | 4 ++-- tests/test_run_train.py | 8 ++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_dev.py index cc2ac1a6..9dd8c61d 100644 --- a/mace/cli/convert_dev.py +++ b/mace/cli/convert_dev.py @@ -1,12 +1,21 @@ from argparse import ArgumentParser + import torch + def main(): parser = ArgumentParser() - parser.add_argument("--target_device", "-t", - help="device to convert to, usually 'cpu' or 'cuda'", default="cpu") - parser.add_argument("--output_file", "-o", - help="name for output model, defaults to model_file.target_device") + parser.add_argument( + "--target_device", + "-t", + help="device to convert to, usually 'cpu' or 'cuda'", + default="cpu", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) parser.add_argument("model_file", help="input model file path") args = parser.parse_args() @@ -17,5 +26,6 @@ def main(): model.to(args.target_device) torch.save(model, args.output_file) + if __name__ == "__main__": main() diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index f44f7515..86470f8d 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -58,7 +58,7 @@ def parse_args() -> argparse.Namespace: help="Model head used for evaluation", type=str, required=False, - default=None + default=None, ) return parser.parse_args() @@ -94,7 +94,7 @@ def run(args: argparse.Namespace) -> None: heads = model.heads except AttributeError: heads = None - + data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 45f11c67..ca196c47 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -1,3 +1,4 @@ +import json import os import subprocess import sys @@ -7,7 +8,6 @@ import numpy as np import pytest from ase.atoms import Atoms -import json from mace.calculators.mace import MACECalculator @@ -608,7 +608,7 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): if i in (0, 1): continue # skip isolated atoms, as energies specified by json files below - elif i % 2 == 0: + if i % 2 == 0: c.info["head"] = "DFT" fitting_configs_dft.append(c) else: @@ -619,9 +619,9 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): # write E0s to json files E0s = {1: 0.0, 8: 0.0} - with open(tmp_path / "fit_multihead_dft.json", "w") as f: + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: json.dump(E0s, f) - with open(tmp_path / "fit_multihead_mp2.json", "w") as f: + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: json.dump(E0s, f) heads = { From 787cda974ba799669e72a20d760074b9d07ad9d5 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:52:57 +0000 Subject: [PATCH 19/22] fix multiple theory single atom case --- mace/cli/run_train.py | 1 - mace/tools/scripts_utils.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 8cab392e..14509e4d 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -653,7 +653,6 @@ def run(args: argparse.Namespace) -> None: folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) for test_name, test_set in test_sets.items(): - print(test_name) test_sampler = None if args.distributed: test_sampler = torch.utils.data.distributed.DistributedSampler( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index ec3d4637..eb70b4d4 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -628,7 +628,7 @@ def custom_key(key): def dict_to_array(input_data, heads): if not all(isinstance(value, dict) for value in input_data.values()): - return np.array(list(input_data.values())) + return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): unique_keys.update(inner_dict.keys()) @@ -640,7 +640,7 @@ def dict_to_array(input_data, heads): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) result_array[head_index][key_index] = value - return np.squeeze(result_array) + return result_array class LRScheduler: From 74ecfdae7778a56a55f932700b4e5ce17665c0d5 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 15:57:08 +0000 Subject: [PATCH 20/22] fix ndarray behavior dict_to_array --- mace/tools/scripts_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index eb70b4d4..3e2e1ed7 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -627,7 +627,9 @@ def custom_key(key): def dict_to_array(input_data, heads): - if not all(isinstance(value, dict) for value in input_data.values()): + if all(isinstance(value, np.ndarray) for value in input_data.values()): + return np.array([input_data[head] for head in heads]) + elif not all(isinstance(value, dict) for value in input_data.values()): return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): From 32d2f97d9ce1805555700fcac0d1c63811b2ca1d Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:45:47 +0000 Subject: [PATCH 21/22] fix loading twice foundation E0s --- mace/cli/run_train.py | 18 +++++++++++++++--- mace/tools/scripts_utils.py | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 14509e4d..9b484d7f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -353,12 +353,18 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict[head_config.head_name] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -372,8 +378,14 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict["pt_head"] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 3e2e1ed7..d20e942b 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -629,7 +629,7 @@ def custom_key(key): def dict_to_array(input_data, heads): if all(isinstance(value, np.ndarray) for value in input_data.values()): return np.array([input_data[head] for head in heads]) - elif not all(isinstance(value, dict) for value in input_data.values()): + if not all(isinstance(value, dict) for value in input_data.values()): return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): From 29c99ae084063230cbb032d43c3dda8b01447ddd Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:02:42 +0000 Subject: [PATCH 22/22] fix formatting --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 9b484d7f..7f1a5e74 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -364,7 +364,7 @@ def run(args: argparse.Namespace) -> None: z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: