From 787cda974ba799669e72a20d760074b9d07ad9d5 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:52:57 +0000 Subject: [PATCH] fix multiple theory single atom case --- mace/cli/run_train.py | 1 - mace/tools/scripts_utils.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 8cab392e..14509e4d 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -653,7 +653,6 @@ def run(args: argparse.Namespace) -> None: folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) for test_name, test_set in test_sets.items(): - print(test_name) test_sampler = None if args.distributed: test_sampler = torch.utils.data.distributed.DistributedSampler( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index ec3d4637..eb70b4d4 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -628,7 +628,7 @@ def custom_key(key): def dict_to_array(input_data, heads): if not all(isinstance(value, dict) for value in input_data.values()): - return np.array(list(input_data.values())) + return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): unique_keys.update(inner_dict.keys()) @@ -640,7 +640,7 @@ def dict_to_array(input_data, heads): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) result_array[head_index][key_index] = value - return np.squeeze(result_array) + return result_array class LRScheduler: