Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save model info #336

Merged
merged 7 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,58 @@ Jupyter Notebook tutorials illustrating the use of currently available calculati
- [Phonons](https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/phonons.ipynb)


## Calculation outputs

By default, calculations performed will modify the underlying [ase.Atoms](https://wiki.fysik.dtu.dk/ase/ase/atoms.html) object
to store information in the `Atoms.info` and `Atoms.arrays` dictionaries about the MLIP used.

Additional dictionary keys include `arch`, corresponding to the MLIP architecture used,
and `model_path`, corresponding to the model path, name or label.

Results from the MLIP calculator, which are typically stored in `Atoms.calc.results`, will also, by default,
be copied to these dictionaries, prefixed by the MLIP `arch`.

For example:

```python
from janus_core.calculations.single_point import SinglePoint

single_point = SinglePoint(
struct_path="tests/data/NaCl.cif",
arch="mace_mp",
model_path="tests/models/mace_mp_small.model",
)

single_point.run()
print(single_point.struct.info)
```

will return

```python
{
'spacegroup': Spacegroup(1, setting=1),
'unit_cell': 'conventional',
'occupancy': {'0': {'Na': 1.0}, '1': {'Cl': 1.0}, '2': {'Na': 1.0}, '3': {'Cl': 1.0}, '4': {'Na': 1.0}, '5': {'Cl': 1.0}, '6': {'Na': 1.0}, '7': {'Cl': 1.0}},
'model_path': 'tests/models/mace_mp_small.model',
'arch': 'mace_mp',
'mace_mp_energy': -27.035127799332745,
'mace_mp_stress': array([-4.78327600e-03, -4.78327600e-03, -4.78327600e-03, 1.08000967e-19, -2.74004242e-19, -2.04504710e-19]),
'system_name': 'NaCl',
}
```

> [!NOTE]
> If running calculations with multiple MLIPs, `arch` and `mlip_model` will be overwritten with the most recent MLIP information.
> Results labelled by the architecture (e.g. `mace_mp_energy`) will be saved between MLIPs,
> unless the same `arch` is chosen, in which case these values will also be overwritten.

This is also the case the calculations performed using the CLI, with the same information written to extxyz output files.

> [!TIP]
> For complete provenance tracking, calculations and training can be run using the [aiida-mlip](https://github.com/stfc/aiida-mlip/) AiiDA plugin.


## Development

We recommend installing poetry for dependency management when developing for `janus-core`:
Expand Down
27 changes: 27 additions & 0 deletions docs/source/user_guide/command_line.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,33 @@ This will run a singlepoint energy calculation on ``KCl.cif`` using the `MACE-MP
Example configurations for all commands can be found in `janus-tutorials <https://github.com/stfc/janus-tutorials/tree/main/configs>`_


Output files
------------

By default, calculations performed will modify the underlying `ase.Atoms <https://wiki.fysik.dtu.dk/ase/ase/atoms.html>`_ object
to store information in the ``Atoms.info`` and ``Atoms.arrays`` dictionaries about the MLIP used.

Additional dictionary keys include ``arch``, corresponding to the MLIP architecture used,
and ``model_path``, corresponding to the model path, name or label.

Results from the MLIP calculator, which are typically stored in ``Atoms.calc.results``, will also, by default,
be copied to these dictionaries, prefixed by the MLIP ``arch``.

This information is then saved when extxyz files are written. For example:

.. code-block:: bash

janus singlepoint --struct tests/data/NaCl.cif --arch mace_mp --model-path /path/to/mace/model


Generates an output file, ``NaCl-results.extxyz``, with ``arch``, ``model_path``, ``mace_mp_energy``, ``mace_mp_forces``, and ``mace_mp_stress``.

.. note::
If running calculations with multiple MLIPs, ``arch`` and ``mlip_model`` will be overwritten with the most recent MLIP information.
Results labelled by the architecture (e.g. ``mace_mp_energy``) will be saved between MLIPs,
unless the same ``arch`` is chosen, in which case these values will also be overwritten.


Single point calculations
-------------------------

Expand Down
48 changes: 48 additions & 0 deletions docs/source/user_guide/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,51 @@ Jupyter Notebook tutorials illustrating the use of currently available calculati
- `Molecular Dynamics <https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/md.ipynb>`_
- `Equation of State <https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/eos.ipynb>`_
- `Phonons <https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/phonons.ipynb>`_


Calculation outputs
===================

By default, calculations performed will modify the underlying :class:`ase.Atoms` object
to store information in the ``Atoms.info`` and ``Atoms.arrays`` dictionaries about the MLIP used.

Additional dictionary keys include ``arch``, corresponding to the MLIP architecture used,
and ``model_path``, corresponding to the model path, name or label.

Results from the MLIP calculator, which are typically stored in ``Atoms.calc.results``, will also,
by default, be copied to these dictionaries, prefixed by the MLIP ``arch``.

For example:

.. code-block:: python

from janus_core.calculations.single_point import SinglePoint

single_point = SinglePoint(
struct_path="tests/data/NaCl.cif",
arch="mace_mp",
model_path="tests/models/mace_mp_small.model",
)

single_point.run()
print(single_point.struct.info)

will return

.. code-block:: python

{
'spacegroup': Spacegroup(1, setting=1),
'unit_cell': 'conventional',
'occupancy': {'0': {'Na': 1.0}, '1': {'Cl': 1.0}, '2': {'Na': 1.0}, '3': {'Cl': 1.0}, '4': {'Na': 1.0}, '5': {'Cl': 1.0}, '6': {'Na': 1.0}, '7': {'Cl': 1.0}},
'model_path': 'tests/models/mace_mp_small.model',
'arch': 'mace_mp',
'mace_mp_energy': -27.035127799332745,
'mace_mp_stress': array([-4.78327600e-03, -4.78327600e-03, -4.78327600e-03, 1.08000967e-19, -2.74004242e-19, -2.04504710e-19]),
'system_name': 'NaCl',
}

.. note::
If running calculations with multiple MLIPs, ``arch`` and ``mlip_model`` will be overwritten with the most recent MLIP information.
Results labelled by the architecture (e.g. ``mace_mp_energy``) will be saved between MLIPs,
unless the same ``arch`` is chosen, in which case these values will also be overwritten.
36 changes: 19 additions & 17 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,20 @@ def choose_calculator(
from mace.calculators import mace_mp

# Default to "small" model and float64 precision
model = model_path if model_path else "small"
model_path = model_path if model_path else "small"
kwargs.setdefault("default_dtype", "float64")

calculator = mace_mp(model=model, device=device, **kwargs)
calculator = mace_mp(model=model_path, device=device, **kwargs)

elif arch == "mace_off":
from mace import __version__
from mace.calculators import mace_off

# Default to "small" model and float64 precision
model = model_path if model_path else "small"
model_path = model_path if model_path else "small"
kwargs.setdefault("default_dtype", "float64")

calculator = mace_off(model=model, device=device, **kwargs)
calculator = mace_off(model=model_path, device=device, **kwargs)

elif arch == "m3gnet":
from matgl import __version__, load_model
Expand All @@ -154,14 +154,16 @@ def choose_calculator(
# Otherwise, load the model if given a path, else use a default model
if isinstance(model_path, Potential):
potential = model_path
model_path = "loaded_Potential"
elif isinstance(model_path, Path):
if model_path.is_file():
model_path = model_path.parent
potential = load_model(model_path)
elif isinstance(model_path, str):
potential = load_model(model_path)
else:
potential = load_model("M3GNet-MP-2021.2.8-DIRECT-PES")
model_path = "M3GNet-MP-2021.2.8-DIRECT-PES"
potential = load_model(model_path)

calculator = M3GNetCalculator(potential=potential, **kwargs)

Expand All @@ -178,11 +180,13 @@ def choose_calculator(
# Otherwise, load the model if given a path, else use a default model
if isinstance(model_path, CHGNet):
model = model_path
model_path = "loaded_CHGNet"
elif isinstance(model_path, Path):
model = CHGNet.from_file(model_path)
elif isinstance(model_path, str):
model = CHGNet.load(model_name=model_path, use_device=device)
else:
model_path = "0.3.0"
model = None

calculator = CHGNetCalculator(model=model, use_device=device, **kwargs)
Expand All @@ -197,31 +201,28 @@ def choose_calculator(

# Set default path to directory containing config and model location
if isinstance(model_path, Path):
path = model_path
if path.is_file():
path = path.parent
if model_path.is_file():
model_path = model_path.parent
# If a string, assume referring to model_name e.g. "v5.27.2024"
elif isinstance(model_path, str):
path = get_figshare_model_ff(model_name=model_path)
model_path = get_figshare_model_ff(model_name=model_path)
else:
path = default_path()
model_path = default_path()

calculator = AlignnAtomwiseCalculator(path=path, device=device, **kwargs)
calculator = AlignnAtomwiseCalculator(path=model_path, device=device, **kwargs)

elif arch == "sevennet":
from sevenn import __version__
from sevenn.sevennet_calculator import SevenNetCalculator

if isinstance(model_path, Path):
model = str(model_path)
elif isinstance(model_path, str):
model = model_path
else:
model = "SevenNet-0_11July2024"
model_path = str(model_path)
elif not isinstance(model_path, str):
model_path = "SevenNet-0_11July2024"

kwargs.setdefault("file_type", "checkpoint")
kwargs.setdefault("sevennet_config", None)
calculator = SevenNetCalculator(model=model, device=device, **kwargs)
calculator = SevenNetCalculator(model=model_path, device=device, **kwargs)

else:
raise ValueError(
Expand All @@ -231,6 +232,7 @@ def choose_calculator(

calculator.parameters["version"] = __version__
calculator.parameters["arch"] = arch
calculator.parameters["model_path"] = str(model_path)

return calculator

Expand Down
5 changes: 5 additions & 0 deletions janus_core/helpers/struct_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def results_to_info(
if not properties:
properties = get_args(Properties)

if struct.calc and "model_path" in struct.calc.parameters:
struct.info["model_path"] = struct.calc.parameters["model_path"]

# Only add to info if MLIP calculator with "arch" parameter set
if struct.calc and "arch" in struct.calc.parameters:
arch = struct.calc.parameters["arch"]
Expand Down Expand Up @@ -265,6 +268,8 @@ def output_structs(
for image in images:
if image.calc and "arch" in image.calc.parameters:
image.info["arch"] = image.calc.parameters["arch"]
if image.calc and "model_path" in image.calc.parameters:
image.info["model_path"] = image.calc.parameters["model_path"]

# Add label for system
for image in images:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_mlips(arch, device, kwargs):
"""Test mace calculators can be configured."""
calculator = choose_calculator(arch=arch, device=device, **kwargs)
assert calculator.parameters["version"] is not None
assert calculator.parameters["model_path"] is not None


def test_invalid_arch():
Expand Down Expand Up @@ -127,6 +128,7 @@ def test_extra_mlips(arch, device, kwargs):
**kwargs,
)
assert calculator.parameters["version"] is not None
assert calculator.parameters["model_path"] is not None
except BadZipFile:
pytest.skip()

Expand Down
14 changes: 10 additions & 4 deletions tests/test_singlepoint_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,23 @@ def test_singlepoint():
assert summary_path.exists

finally:
# Check atoms can read read, then delete file
# Ensure files deleted if command fails
log_path.unlink(missing_ok=True)
summary_path.unlink(missing_ok=True)

# Check atoms file can be read, then delete
atoms = read_atoms(results_path)
assert "mace_mp_energy" in atoms.info

assert "arch" in atoms.info
assert "model_path" in atoms.info
assert atoms.info["arch"] == "mace_mp"
assert atoms.info["model_path"] == "small"

assert "mace_mp_forces" in atoms.arrays
assert "system_name" in atoms.info
assert atoms.info["system_name"] == "NaCl"

# Ensure files deleted if command fails
log_path.unlink(missing_ok=True)
summary_path.unlink(missing_ok=True)
clear_log_handlers()


Expand Down
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def test_output_structs(
else:
results_keys = {"energy", "forces", "stress"}

if arch == "mace_mp":
model = "small"
if arch == "m3gnet":
model = "M3GNet-MP-2021.2.8-DIRECT-PES"
if arch == "chgnet":
model = "0.3.0"

label_keys = {f"{arch}_{key}" for key in results_keys}

write_kwargs = {}
Expand Down Expand Up @@ -115,6 +122,7 @@ def test_output_structs(
if "set_info" not in write_kwargs or write_kwargs["set_info"]:
assert label_keys <= struct.info.keys() | struct.arrays.keys()
assert struct.info["arch"] == arch
assert struct.info["model_path"] == model

# Check file written correctly if write_results
if write_results:
Expand All @@ -126,6 +134,7 @@ def test_output_structs(
if "set_info" not in write_kwargs or write_kwargs["set_info"]:
assert label_keys <= atoms.info.keys() | atoms.arrays.keys()
assert atoms.info["arch"] == arch
assert atoms.info["model_path"] == model

# Check calculator results depend on invalidate_calc
if invalidate_calc:
Expand Down