Skip to content

Commit

Permalink
fix loading twice foundation E0s
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 4, 2024
1 parent 74ecfda commit 32d2f97
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
18 changes: 15 additions & 3 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,18 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict[head_config.head_name] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
}
}
else:
atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table)
else:
Expand All @@ -372,8 +378,14 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict["pt_head"] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
Expand Down
2 changes: 1 addition & 1 deletion mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def custom_key(key):
def dict_to_array(input_data, heads):
if all(isinstance(value, np.ndarray) for value in input_data.values()):
return np.array([input_data[head] for head in heads])
elif not all(isinstance(value, dict) for value in input_data.values()):
if not all(isinstance(value, dict) for value in input_data.values()):
return np.array([[input_data[head]] for head in heads])
unique_keys = set()
for inner_dict in input_data.values():
Expand Down

0 comments on commit 32d2f97

Please sign in to comment.