diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 866ffeb5..6d9238ed 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -9,10 +9,9 @@ import json import logging import os +import urllib.request from pathlib import Path from typing import Optional -import urllib.request - import numpy as np import torch.distributed @@ -27,27 +26,28 @@ from mace.calculators.foundations_models import mace_mp, mace_off from mace.cli.fine_tuning_select import select_samples from mace.tools import torch_geometric +from mace.tools.finetuning_utils import ( + extract_config_mace_model, + load_foundations_elements, +) from mace.tools.scripts_utils import ( LRScheduler, + check_folder_subfolder, create_error_table, + dict_to_array, dict_to_namespace, get_atomic_energies, get_config_type_weights, get_dataset_from_xyz, get_files_with_suffix, - dict_to_array, - check_folder_subfolder, ) from mace.tools.slurm_distributed import DistributedEnvironment -from mace.tools.finetuning_utils import ( - load_foundations_elements, - extract_config_mace_model, -) from mace.tools.utils import AtomicNumberTable def main() -> None: args = tools.build_default_arg_parser().parse_args() + tools.check_args(args) tag = tools.get_tag(name=args.name, seed=args.seed) if args.device == "xpu": diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 5f851483..e63f4c41 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,6 +1,11 @@ -from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser import ( + build_default_arg_parser, + build_preprocess_arg_parser, + check_args, +) from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState +from .finetuning_utils import extract_load, load_foundations_elements from .torch_tools import ( TensorDict, cartesian_to_spherical, @@ -31,8 +36,6 @@ setup_logger, ) -from .finetuning_utils import load_foundations_elements, extract_load - __all__ = [ "TensorDict", "AtomicNumberTable", @@ -40,6 +43,7 @@ "to_numpy", "to_one_hot", "build_default_arg_parser", + "check_args", "set_seeds", "init_device", "setup_logger", diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index abb550e1..5f91897f 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -636,6 +636,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "forces_weight", ], ) + parser.add_argument( + "--force", + help="Ignore checks for inconsistency in arguments. Should only be used for testing.", + action="store_true", + default=False, + ) return parser @@ -784,3 +790,16 @@ def check_float_or_none(value: str) -> Optional[float]: f"{value} is an invalid value (float or None)" ) from None return None + + +def __check_e0s_and_finetuning(e0s: str, finetuning: bool) -> None: + if e0s == "average" and finetuning: + raise ValueError( + "Cannot use average E0s with finetuning, please provide E0s for each element." + ) + + +def check_args(args: argparse.Namespace) -> None: + if args.force: + return + __check_e0s_and_finetuning(args.E0s, args.foundation_model is not None) diff --git a/tests/test_argparser.py b/tests/test_argparser.py new file mode 100644 index 00000000..13558c19 --- /dev/null +++ b/tests/test_argparser.py @@ -0,0 +1,56 @@ +import pytest + +from mace.tools import build_default_arg_parser, check_args + + +def test_finetuning_with_e0s_average_raises_error(): + parser = build_default_arg_parser() + args = parser.parse_args( + [ + "--name", + "_", + "--train_file", + "_", + "--foundation_model", + "_", + "--E0s", + "average", + ] + ) + with pytest.raises(ValueError): + check_args(args) + + +def test_force_flag_skips_check(): + parser = build_default_arg_parser() + args = parser.parse_args( + [ + "--name", + "_", + "--train_file", + "_", + "--foundation_model", + "_", + "--E0s", + "average", + "--force", + ] + ) + check_args(args) + + +def test_finetuning_with_non_average_e0s_does_not_raise_error(): + parser = build_default_arg_parser() + args = parser.parse_args( + [ + "--name", + "_", + "--train_file", + "_", + "--foundation_model", + "_", + "--E0s", + "precomputes_e0s.json", + ] + ) + check_args(args)