From 157c7b3fb9fd40809f8cde4ec498ea8b5f85d999 Mon Sep 17 00:00:00 2001 From: Hubert Beck Date: Fri, 20 Dec 2024 16:22:48 +0000 Subject: [PATCH 1/3] Correct assignment of head --- mace/data/hdf5_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 477ccd3f..08ba1667 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -66,8 +66,7 @@ def __getitem__(self, index): pbc=unpack_value(subgrp["pbc"][()]), cell=unpack_value(subgrp["cell"][()]), ) - if config.head is None: - config.head = self.kwargs.get("head") + config.head = self.kwargs.get("head", "Default") atomic_data = AtomicData.from_config( config, z_table=self.z_table, From 321e0a6cd247cf93ee976c40a99759503c7e4aa4 Mon Sep 17 00:00:00 2001 From: Hubert Beck Date: Fri, 20 Dec 2024 18:20:09 +0000 Subject: [PATCH 2/3] fix preprocessed test sets --- mace/cli/preprocess_data.py | 4 +++- mace/cli/run_train.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index ef9f1343..a6ac7001 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -123,7 +123,7 @@ def multi_valid_hdf5(process, args, split_valid, drop_last): def multi_test_hdf5(process, name, args, split_test, drop_last): with h5py.File( - args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" + args.h5_prefix + "test/" + name + "/" + "test_" + str(process) + ".h5", "w" ) as f: f.attrs["drop_last"] = drop_last save_configurations_as_HDF5(split_test[process], process, f) @@ -268,6 +268,8 @@ def run(args: argparse.Namespace): if args.test_file is not None: logging.info("Preparing test sets") for name, subset in collections.tests: + if not os.path.exists(args.h5_prefix + "test/" + name): + os.makedirs(args.h5_prefix + "test/" + name) drop_last = False if len(subset) % 2 == 1: drop_last = True diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1c0898b7..6f82ebc5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -706,7 +706,7 @@ def run(args: argparse.Namespace) -> None: else: test_folders = glob(head_config.test_dir + "/*") for folder in test_folders: - name = os.path.splitext(os.path.basename(test_file))[0] + name = os.path.splitext(os.path.basename(folder))[0] test_sets[name] = data.dataset_from_sharded_hdf5( folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) From 15c22318ae1406282de27466dd368e18cf47de9a Mon Sep 17 00:00:00 2001 From: Hubert Beck Date: Fri, 20 Dec 2024 18:22:19 +0000 Subject: [PATCH 3/3] import glob correctly --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 6f82ebc5..661e9d55 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -6,11 +6,11 @@ import argparse import ast -import glob import json import logging import os from copy import deepcopy +from glob import glob from pathlib import Path from typing import List, Optional