Skip to content

Commit

Permalink
fix preprocessed test sets
Browse files Browse the repository at this point in the history
  • Loading branch information
beckobert committed Dec 20, 2024
1 parent 157c7b3 commit 321e0a6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion mace/cli/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def multi_valid_hdf5(process, args, split_valid, drop_last):

def multi_test_hdf5(process, name, args, split_test, drop_last):
with h5py.File(
args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
args.h5_prefix + "test/" + name + "/" + "test_" + str(process) + ".h5", "w"
) as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_test[process], process, f)
Expand Down Expand Up @@ -268,6 +268,8 @@ def run(args: argparse.Namespace):
if args.test_file is not None:
logging.info("Preparing test sets")
for name, subset in collections.tests:
if not os.path.exists(args.h5_prefix + "test/" + name):
os.makedirs(args.h5_prefix + "test/" + name)
drop_last = False
if len(subset) % 2 == 1:
drop_last = True
Expand Down
2 changes: 1 addition & 1 deletion mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def run(args: argparse.Namespace) -> None:
else:
test_folders = glob(head_config.test_dir + "/*")
for folder in test_folders:
name = os.path.splitext(os.path.basename(test_file))[0]
name = os.path.splitext(os.path.basename(folder))[0]
test_sets[name] = data.dataset_from_sharded_hdf5(
folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
)
Expand Down

0 comments on commit 321e0a6

Please sign in to comment.