diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index ef9f1343..a6ac7001 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -123,7 +123,7 @@ def multi_valid_hdf5(process, args, split_valid, drop_last): def multi_test_hdf5(process, name, args, split_test, drop_last): with h5py.File( - args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" + args.h5_prefix + "test/" + name + "/" + "test_" + str(process) + ".h5", "w" ) as f: f.attrs["drop_last"] = drop_last save_configurations_as_HDF5(split_test[process], process, f) @@ -268,6 +268,8 @@ def run(args: argparse.Namespace): if args.test_file is not None: logging.info("Preparing test sets") for name, subset in collections.tests: + if not os.path.exists(args.h5_prefix + "test/" + name): + os.makedirs(args.h5_prefix + "test/" + name) drop_last = False if len(subset) % 2 == 1: drop_last = True diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1c0898b7..661e9d55 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -6,11 +6,11 @@ import argparse import ast -import glob import json import logging import os from copy import deepcopy +from glob import glob from pathlib import Path from typing import List, Optional @@ -706,7 +706,7 @@ def run(args: argparse.Namespace) -> None: else: test_folders = glob(head_config.test_dir + "/*") for folder in test_folders: - name = os.path.splitext(os.path.basename(test_file))[0] + name = os.path.splitext(os.path.basename(folder))[0] test_sets[name] = data.dataset_from_sharded_hdf5( folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 477ccd3f..08ba1667 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -66,8 +66,7 @@ def __getitem__(self, index): pbc=unpack_value(subgrp["pbc"][()]), cell=unpack_value(subgrp["cell"][()]), ) - if config.head is None: - config.head = self.kwargs.get("head") + config.head = self.kwargs.get("head", "Default") atomic_data = AtomicData.from_config( config, z_table=self.z_table,