-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add preprocessing training data (#344)
* Tidy training * Add preprocessing * Refactor file check * Fix preprocess config * Add preprocess docs * Apply suggestions from code review Co-authored-by: Jacob Wilkins <[email protected]> --------- Co-authored-by: Jacob Wilkins <[email protected]>
- Loading branch information
1 parent
bdeaca9
commit ca1d799
Showing
12 changed files
with
525 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# noqa: I002, FA102 | ||
"""Set up MLIP preprocessing commandline interface.""" | ||
|
||
# Issues with future annotations and typer | ||
# c.f. https://github.com/maxb2/typer-config/issues/295 | ||
# from __future__ import annotations | ||
|
||
from pathlib import Path | ||
from typing import Annotated | ||
|
||
from typer import Option, Typer | ||
|
||
app = Typer() | ||
|
||
|
||
@app.command() | ||
def preprocess( | ||
mlip_config: Annotated[ | ||
Path, Option(help="Configuration file to pass to MLIP CLI.") | ||
], | ||
log: Annotated[Path, Option(help="Path to save logs to.")] = Path( | ||
"preprocess-log.yml" | ||
), | ||
tracker: Annotated[ | ||
bool, Option(help="Whether to save carbon emissions of calculation") | ||
] = True, | ||
summary: Annotated[ | ||
Path, | ||
Option( | ||
help=( | ||
"Path to save summary of inputs, start/end time, and carbon emissions." | ||
) | ||
), | ||
] = Path("preprocess-summary.yml"), | ||
): | ||
""" | ||
Convert training data to hdf5 by passing a configuration file to the MLIP's CLI. | ||
Parameters | ||
---------- | ||
mlip_config : Path | ||
Configuration file to pass to MLIP CLI. | ||
log : Optional[Path] | ||
Path to write logs to. Default is Path("preprocess-log.yml"). | ||
tracker : bool | ||
Whether to save carbon emissions of calculation in log file and summary. | ||
Default is True. | ||
summary : Optional[Path] | ||
Path to save summary of inputs, start/end time, and carbon emissions. Default | ||
is Path("preprocess-summary.yml"). | ||
""" | ||
from janus_core.cli.utils import carbon_summary, end_summary, start_summary | ||
from janus_core.training.preprocess import preprocess as run_preprocess | ||
|
||
inputs = {"mlip_config": str(mlip_config)} | ||
|
||
# Save summary information before preprocessing begins | ||
start_summary(command="preprocess", summary=summary, inputs=inputs) | ||
|
||
log_kwargs = {"filemode": "w"} | ||
if log: | ||
log_kwargs["filename"] = log | ||
|
||
# Run preprocessing | ||
run_preprocess( | ||
mlip_config, attach_logger=True, log_kwargs=log_kwargs, track_carbon=tracker | ||
) | ||
|
||
# Save carbon summary | ||
if tracker: | ||
carbon_summary(summary=summary, log=log) | ||
|
||
# Save time after preprocessing has finished | ||
end_summary(summary) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Preprocess MLIP training data.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
from typing import Any | ||
|
||
from mace.cli.preprocess_data import run | ||
from mace.tools import build_preprocess_arg_parser as mace_parser | ||
import yaml | ||
|
||
from janus_core.helpers.janus_types import PathLike | ||
from janus_core.helpers.log import config_logger, config_tracker | ||
from janus_core.helpers.utils import check_files_exist, none_to_dict | ||
|
||
|
||
def preprocess( | ||
mlip_config: PathLike, | ||
req_file_keys: Sequence[PathLike] = ("train_file", "test_file", "valid_file"), | ||
attach_logger: bool = False, | ||
log_kwargs: dict[str, Any] | None = None, | ||
track_carbon: bool = True, | ||
tracker_kwargs: dict[str, Any] | None = None, | ||
) -> None: | ||
""" | ||
Convert training data to hdf5 by passing a configuration file to the MLIP's CLI. | ||
Currently only supports MACE models, but this can be extended by replacing the | ||
argument parsing. | ||
Parameters | ||
---------- | ||
mlip_config : PathLike | ||
Configuration file to pass to MLIP. | ||
req_file_keys : Sequence[PathLike] | ||
List of files that must exist if defined in the configuration file. | ||
Default is ("train_file", "test_file", "valid_file"). | ||
attach_logger : bool | ||
Whether to attach a logger. Default is False. | ||
log_kwargs : dict[str, Any] | None | ||
Keyword arguments to pass to `config_logger`. Default is {}. | ||
track_carbon : bool | ||
Whether to track carbon emissions of calculation. Default is True. | ||
tracker_kwargs : dict[str, Any] | None | ||
Keyword arguments to pass to `config_tracker`. Default is {}. | ||
""" | ||
log_kwargs, tracker_kwargs = none_to_dict(log_kwargs, tracker_kwargs) | ||
|
||
# Validate inputs | ||
with open(mlip_config, encoding="utf8") as file: | ||
options = yaml.safe_load(file) | ||
check_files_exist(options, req_file_keys) | ||
|
||
# Configure logging | ||
if attach_logger: | ||
log_kwargs.setdefault("filename", "preprocess-log.yml") | ||
log_kwargs.setdefault("name", __name__) | ||
logger = config_logger(**log_kwargs) | ||
tracker = config_tracker(logger, track_carbon, **tracker_kwargs) | ||
|
||
if logger and "foundation_model" in options: | ||
logger.info("Fine tuning model: %s", options["foundation_model"]) | ||
|
||
# Parse options from config, as MACE cannot read config file yet | ||
args = [] | ||
for key, value in options.items(): | ||
if isinstance(value, bool): | ||
if value is True: | ||
args.append(f"--{key}") | ||
else: | ||
args.append(f"--{key}") | ||
args.append(f"{value}") | ||
|
||
mlip_args = mace_parser().parse_args(args) | ||
|
||
if logger: | ||
logger.info("Starting preprocessing") | ||
if tracker: | ||
tracker.start_task("Preprocessing") | ||
|
||
run(mlip_args) | ||
|
||
if logger: | ||
logger.info("Preprocessing complete") | ||
if tracker: | ||
tracker.stop_task() | ||
tracker.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
train_file: "tests/data/mlip_train.xyz" | ||
valid_file: "tests/data/mlip_valid.xyz" | ||
test_file: "tests/data/mlip_test.xyz" | ||
energy_key: 'dft_energy' | ||
forces_key: 'dft_forces' | ||
stress_key: 'dft_stress' | ||
r_max: 4.0 | ||
scaling: 'rms_forces_scaling' | ||
batch_size: 4 | ||
seed: 2024 | ||
compute_statistics: False |
Oops, something went wrong.