From 32d2f97d9ce1805555700fcac0d1c63811b2ca1d Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:45:47 +0000 Subject: [PATCH] fix loading twice foundation E0s --- mace/cli/run_train.py | 18 +++++++++++++++--- mace/tools/scripts_utils.py | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 14509e4d..9b484d7f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -353,12 +353,18 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict[head_config.head_name] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -372,8 +378,14 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict["pt_head"] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 3e2e1ed7..d20e942b 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -629,7 +629,7 @@ def custom_key(key): def dict_to_array(input_data, heads): if all(isinstance(value, np.ndarray) for value in input_data.values()): return np.array([input_data[head] for head in heads]) - elif not all(isinstance(value, dict) for value in input_data.values()): + if not all(isinstance(value, dict) for value in input_data.values()): return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values():