diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py
index fa1933d9..fb3d4939 100644
--- a/janus_core/helpers/utils.py
+++ b/janus_core/helpers/utils.py
@@ -428,7 +428,8 @@ def check_files_exist(config: dict, req_file_keys: Sequence[PathLike]) -> None:
         If a key from `req_file_keys` is in the configuration file, but the file
         corresponding to the configuration value do not exist.
     """
-    for file_key in req_file_keys:
-        # Only check if file key is in the configuration file
-        if file_key in config and not Path(config[file_key]).exists():
+    # Only check keys that are in the configuration file. Filter in
+    # `req_file_keys` order rather than via a set intersection, so the missing
+    # file reported first does not depend on hash ordering.
+    for file_key in (key for key in req_file_keys if key in config):
+        if not Path(config[file_key]).exists():
             raise FileNotFoundError(f"{config[file_key]} does not exist")
diff --git a/janus_core/training/preprocess.py b/janus_core/training/preprocess.py
index dcfc0be0..857042a9 100644
--- a/janus_core/training/preprocess.py
+++ b/janus_core/training/preprocess.py
@@ -16,7 +16,7 @@
 
 def preprocess(
     mlip_config: PathLike,
-    req_file_keys: Sequence[PathLike] | None = None,
+    req_file_keys: Sequence[PathLike] = ("train_file", "test_file", "valid_file"),
     attach_logger: bool = False,
     log_kwargs: dict[str, Any] | None = None,
     track_carbon: bool = True,
@@ -32,7 +32,7 @@ def preprocess(
     ----------
     mlip_config : PathLike
         Configuration file to pass to MLIP.
-    req_file_keys : Sequence[PathLike] | None
+    req_file_keys : Sequence[PathLike]
         List of files that must exist if defined in the configuration file.
         Default is ("train_file", "test_file", "valid_file").
     attach_logger : bool
@@ -44,10 +44,7 @@ def preprocess(
     tracker_kwargs : dict[str, Any] | None
         Keyword arguments to pass to `config_tracker`. Default is {}.
     """
-    (log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs))
-
-    if req_file_keys is None:
-        req_file_keys = ("train_file", "test_file", "valid_file")
+    log_kwargs, tracker_kwargs = none_to_dict((log_kwargs, tracker_kwargs))
 
     # Validate inputs
     with open(mlip_config, encoding="utf8") as file:
diff --git a/janus_core/training/train.py b/janus_core/training/train.py
index 710c0a47..8cf0f44b 100644
--- a/janus_core/training/train.py
+++ b/janus_core/training/train.py
@@ -16,7 +16,12 @@
 
 def train(
     mlip_config: PathLike,
-    req_file_keys: Sequence[PathLike] | None = None,
+    req_file_keys: Sequence[PathLike] = (
+        "train_file",
+        "test_file",
+        "valid_file",
+        "statistics_file",
+    ),
     attach_logger: bool = False,
     log_kwargs: dict[str, Any] | None = None,
     track_carbon: bool = True,
@@ -32,7 +37,7 @@ def train(
     ----------
     mlip_config : PathLike
         Configuration file to pass to MLIP.
-    req_file_keys : Sequence[PathLike] | None
+    req_file_keys : Sequence[PathLike]
         List of files that must exist if defined in the configuration file.
         Default is ("train_file", "test_file", "valid_file", "statistics_file").
     attach_logger : bool
@@ -44,10 +49,7 @@ def train(
     tracker_kwargs : Optional[dict[str, Any]]
         Keyword arguments to pass to `config_tracker`. Default is {}.
     """
-    (log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs))
-
-    if req_file_keys is None:
-        req_file_keys = ("train_file", "test_file", "valid_file", "statistics_file")
+    log_kwargs, tracker_kwargs = none_to_dict((log_kwargs, tracker_kwargs))
 
     # Validate inputs
     with open(mlip_config, encoding="utf8") as file: