Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Jacob Wilkins <[email protected]>
  • Loading branch information
ElliottKasoar and oerc0122 committed Nov 7, 2024
1 parent 5059547 commit 019facc
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
4 changes: 2 additions & 2 deletions janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def check_files_exist(config: dict, req_file_keys: Sequence[PathLike]) -> None:
If a key from `req_file_keys` is in the configuration file, but the
file corresponding to the configuration value do not exist.
"""
for file_key in req_file_keys:
for file_key in config.keys() & req_file_keys:
# Only check if file key is in the configuration file
if file_key in config and not Path(config[file_key]).exists():
if not Path(config[file_key]).exists():
raise FileNotFoundError(f"{config[file_key]} does not exist")
9 changes: 3 additions & 6 deletions janus_core/training/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def preprocess(
mlip_config: PathLike,
req_file_keys: Sequence[PathLike] | None = None,
req_file_keys: Sequence[PathLike] = ("train_file", "test_file", "valid_file"),
attach_logger: bool = False,
log_kwargs: dict[str, Any] | None = None,
track_carbon: bool = True,
Expand All @@ -32,7 +32,7 @@ def preprocess(
----------
mlip_config : PathLike
Configuration file to pass to MLIP.
req_file_keys : Sequence[PathLike] | None
req_file_keys : Sequence[PathLike]
List of files that must exist if defined in the configuration file.
Default is ("train_file", "test_file", "valid_file").
attach_logger : bool
Expand All @@ -44,10 +44,7 @@ def preprocess(
tracker_kwargs : dict[str, Any] | None
Keyword arguments to pass to `config_tracker`. Default is {}.
"""
(log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs))

if req_file_keys is None:
req_file_keys = ("train_file", "test_file", "valid_file")
log_kwargs, tracker_kwargs = none_to_dict((log_kwargs, tracker_kwargs))

# Validate inputs
with open(mlip_config, encoding="utf8") as file:
Expand Down
14 changes: 8 additions & 6 deletions janus_core/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

def train(
mlip_config: PathLike,
req_file_keys: Sequence[PathLike] | None = None,
req_file_keys: Sequence[PathLike] = (
"train_file",
"test_file",
"valid_file",
"statistics_file",
),
attach_logger: bool = False,
log_kwargs: dict[str, Any] | None = None,
track_carbon: bool = True,
Expand All @@ -32,7 +37,7 @@ def train(
----------
mlip_config : PathLike
Configuration file to pass to MLIP.
req_file_keys : Sequence[PathLike] | None
req_file_keys : Sequence[PathLike]
List of files that must exist if defined in the configuration file.
Default is ("train_file", "test_file", "valid_file", "statistics_file").
attach_logger : bool
Expand All @@ -44,10 +49,7 @@ def train(
tracker_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to `config_tracker`. Default is {}.
"""
(log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs))

if req_file_keys is None:
req_file_keys = ("train_file", "test_file", "valid_file", "statistics_file")
log_kwargs, tracker_kwargs = none_to_dict((log_kwargs, tracker_kwargs))

# Validate inputs
with open(mlip_config, encoding="utf8") as file:
Expand Down

0 comments on commit 019facc

Please sign in to comment.