Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a check for average E0s in the argument parsing #478

Open
wants to merge 1 commit into
base: multi-head-interface
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import json
import logging
import os
import urllib.request
from pathlib import Path
from typing import Optional
import urllib.request


import numpy as np
import torch.distributed
Expand All @@ -27,27 +26,28 @@
from mace.calculators.foundations_models import mace_mp, mace_off
from mace.cli.fine_tuning_select import select_samples
from mace.tools import torch_geometric
from mace.tools.finetuning_utils import (
extract_config_mace_model,
load_foundations_elements,
)
from mace.tools.scripts_utils import (
LRScheduler,
check_folder_subfolder,
create_error_table,
dict_to_array,
dict_to_namespace,
get_atomic_energies,
get_config_type_weights,
get_dataset_from_xyz,
get_files_with_suffix,
dict_to_array,
check_folder_subfolder,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.finetuning_utils import (
load_foundations_elements,
extract_config_mace_model,
)
from mace.tools.utils import AtomicNumberTable


def main() -> None:
args = tools.build_default_arg_parser().parse_args()
tools.check_args(args)
tag = tools.get_tag(name=args.name, seed=args.seed)

if args.device == "xpu":
Expand Down
10 changes: 7 additions & 3 deletions mace/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser
from .arg_parser import (
build_default_arg_parser,
build_preprocess_arg_parser,
check_args,
)
from .cg import U_matrix_real
from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState
from .finetuning_utils import extract_load, load_foundations_elements
from .torch_tools import (
TensorDict,
cartesian_to_spherical,
Expand Down Expand Up @@ -31,15 +36,14 @@
setup_logger,
)

from .finetuning_utils import load_foundations_elements, extract_load

__all__ = [
"TensorDict",
"AtomicNumberTable",
"atomic_numbers_to_indices",
"to_numpy",
"to_one_hot",
"build_default_arg_parser",
"check_args",
"set_seeds",
"init_device",
"setup_logger",
Expand Down
19 changes: 19 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"forces_weight",
],
)
parser.add_argument(
"--force",
help="Ignore checks for inconsistency in arguments. Should only be used for testing.",
action="store_true",
default=False,
)
return parser


Expand Down Expand Up @@ -784,3 +790,16 @@ def check_float_or_none(value: str) -> Optional[float]:
f"{value} is an invalid value (float or None)"
) from None
return None


def __check_e0s_and_finetuning(e0s: str, finetuning: bool) -> None:
    """Reject the 'average' E0s setting when fine-tuning a foundation model.

    Averaged atomic reference energies are computed from the training set and
    would be inconsistent with the energies the foundation model was trained
    with, so this combination is disallowed.

    Args:
        e0s: the value of the ``--E0s`` command-line option.
        finetuning: whether a foundation model was supplied (fine-tuning mode).

    Raises:
        ValueError: if ``e0s`` is ``"average"`` while ``finetuning`` is true.
    """
    if not finetuning:
        return
    if e0s != "average":
        return
    raise ValueError(
        "Cannot use average E0s with finetuning, please provide E0s for each element."
    )


def check_args(args: argparse.Namespace) -> None:
    """Validate cross-argument consistency after command-line parsing.

    Args:
        args: the parsed command-line namespace (from
            ``build_default_arg_parser().parse_args``).

    Raises:
        ValueError: if an inconsistent combination of arguments was given
            (e.g. ``--E0s average`` together with ``--foundation_model``).
    """
    # --force is an escape hatch intended for testing only: skip all checks.
    # getattr with a False default keeps this safe for namespaces that were
    # constructed without the --force attribute (parser default is False).
    if getattr(args, "force", False):
        return
    # Averaged E0s are incompatible with fine-tuning a foundation model.
    __check_e0s_and_finetuning(args.E0s, args.foundation_model is not None)
56 changes: 56 additions & 0 deletions tests/test_argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from mace.tools import build_default_arg_parser, check_args


def test_finetuning_with_e0s_average_raises_error():
    """check_args must reject --E0s average when a foundation model is given."""
    argv = [
        "--name", "_",
        "--train_file", "_",
        "--foundation_model", "_",
        "--E0s", "average",
    ]
    parsed = build_default_arg_parser().parse_args(argv)
    with pytest.raises(ValueError):
        check_args(parsed)


def test_force_flag_skips_check():
    """--force must bypass the E0s/fine-tuning consistency check entirely."""
    argv = [
        "--name", "_",
        "--train_file", "_",
        "--foundation_model", "_",
        "--E0s", "average",
        "--force",
    ]
    parsed = build_default_arg_parser().parse_args(argv)
    # Must not raise even though this combination is otherwise invalid.
    check_args(parsed)


def test_finetuning_with_non_average_e0s_does_not_raise_error():
    """Explicit (non-'average') E0s are accepted alongside a foundation model."""
    argv = [
        "--name", "_",
        "--train_file", "_",
        "--foundation_model", "_",
        "--E0s", "precomputes_e0s.json",
    ]
    parsed = build_default_arg_parser().parse_args(argv)
    # Must not raise: a concrete E0s file is valid for fine-tuning.
    check_args(parsed)