Skip to content

Commit

Permalink
Save model info
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Oct 22, 2024
1 parent e8ee44c commit d8db0b0
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 17 deletions.
36 changes: 19 additions & 17 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,20 @@ def choose_calculator(
from mace.calculators import mace_mp

# Default to "small" model and float64 precision
model = model_path if model_path else "small"
model_path = model_path if model_path else "small"
kwargs.setdefault("default_dtype", "float64")

calculator = mace_mp(model=model, device=device, **kwargs)
calculator = mace_mp(model=model_path, device=device, **kwargs)

elif arch == "mace_off":
from mace import __version__
from mace.calculators import mace_off

# Default to "small" model and float64 precision
model = model_path if model_path else "small"
model_path = model_path if model_path else "small"
kwargs.setdefault("default_dtype", "float64")

calculator = mace_off(model=model, device=device, **kwargs)
calculator = mace_off(model=model_path, device=device, **kwargs)

elif arch == "m3gnet":
from matgl import __version__, load_model
Expand All @@ -152,14 +152,16 @@ def choose_calculator(
# Otherwise, load the model if given a path, else use a default model
if isinstance(model_path, Potential):
potential = model_path
model_path = "loaded_Potential"
elif isinstance(model_path, Path):
if model_path.is_file():
model_path = model_path.parent
potential = load_model(model_path)
elif isinstance(model_path, str):
potential = load_model(model_path)
else:
potential = load_model("M3GNet-MP-2021.2.8-DIRECT-PES")
model_path = "M3GNet-MP-2021.2.8-DIRECT-PES"
potential = load_model(model_path)

calculator = M3GNetCalculator(potential=potential, **kwargs)

Expand All @@ -176,11 +178,13 @@ def choose_calculator(
# Otherwise, load the model if given a path, else use a default model
if isinstance(model_path, CHGNet):
model = model_path
model_path = "loaded_CHGNet"
elif isinstance(model_path, Path):
model = CHGNet.from_file(model_path)
elif isinstance(model_path, str):
model = CHGNet.load(model_name=model_path, use_device=device)
else:
model_path = "0.3.0"
model = None

calculator = CHGNetCalculator(model=model, use_device=device, **kwargs)
Expand All @@ -195,32 +199,29 @@ def choose_calculator(

# Set default path to directory containing config and model location
if isinstance(model_path, Path):
path = model_path
if path.is_file():
path = path.parent
if model_path.is_file():
model_path = model_path.parent
# If a string, assume referring to model_name e.g. "v5.27.2024"
elif isinstance(model_path, str):
path = get_figshare_model_ff(model_name=model_path)
model_path = get_figshare_model_ff(model_name=model_path)
else:
path = default_path()
model_path = default_path()

calculator = AlignnAtomwiseCalculator(path=path, device=device, **kwargs)
calculator = AlignnAtomwiseCalculator(path=model_path, device=device, **kwargs)

elif arch == "sevennet":
# Disable constant-imported-as-non-constant
from sevenn._const import SEVENN_VERSION as __version__ # noqa: N811
from sevenn.sevennet_calculator import SevenNetCalculator

if isinstance(model_path, Path):
model = str(model_path)
elif isinstance(model_path, str):
model = model_path
else:
model = "SevenNet-0_11July2024"
model_path = str(model_path)
elif not isinstance(model_path, str):
model_path = "SevenNet-0_11July2024"

kwargs.setdefault("file_type", "checkpoint")
kwargs.setdefault("sevennet_config", None)
calculator = SevenNetCalculator(model=model, device=device, **kwargs)
calculator = SevenNetCalculator(model=model_path, device=device, **kwargs)

else:
raise ValueError(
Expand All @@ -230,5 +231,6 @@ def choose_calculator(

calculator.parameters["version"] = __version__
calculator.parameters["arch"] = arch
calculator.parameters["model"] = str(model_path)

return calculator
4 changes: 4 additions & 0 deletions janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def results_to_info(
if struct.calc and "arch" in struct.calc.parameters:
arch = struct.calc.parameters["arch"]
struct.info["arch"] = arch
if struct.calc and "mlip_model" in struct.calc.parameters:
struct.info["mlip_model"] = struct.calc.parameters["model"]

for key in properties & struct.calc.results.keys():
tag = f"{arch}_{key}"
Expand Down Expand Up @@ -474,6 +476,8 @@ def output_structs(
for image in images:
if image.calc and "arch" in image.calc.parameters:
image.info["arch"] = image.calc.parameters["arch"]
if image.calc and "mlip_model" in image.calc.parameters:
image.info["mlip_model"] = image.calc.parameters["model"]

# Add label for system
for image in images:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_mlips(arch, device, kwargs):
"""Test mace calculators can be configured."""
calculator = choose_calculator(arch=arch, device=device, **kwargs)
assert calculator.parameters["version"] is not None
assert calculator.parameters["model"] is not None


def test_invalid_arch():
Expand Down Expand Up @@ -127,6 +128,7 @@ def test_extra_mlips(arch, device, kwargs):
**kwargs,
)
assert calculator.parameters["version"] is not None
assert calculator.parameters["model"] is not None
except BadZipFile:
pytest.skip()

Expand Down
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def test_output_structs(
else:
results_keys = {"energy", "forces", "stress"}

if arch == "mace_mp":
model = "small"
if arch == "m3gnet":
model = "M3GNet-MP-2021.2.8-DIRECT-PES"
if arch == "chgnet":
model = "0.3.0"

label_keys = {f"{arch}_{key}" for key in results_keys}

write_kwargs = {}
Expand Down Expand Up @@ -114,6 +121,7 @@ def test_output_structs(
if "set_info" not in write_kwargs or write_kwargs["set_info"]:
assert label_keys <= struct.info.keys() | struct.arrays.keys()
assert struct.info["arch"] == arch
assert struct.info["mlip_model"] == model

# Check file written correctly if write_results
if write_results:
Expand All @@ -125,6 +133,7 @@ def test_output_structs(
if "set_info" not in write_kwargs or write_kwargs["set_info"]:
assert label_keys <= atoms.info.keys() | atoms.arrays.keys()
assert atoms.info["arch"] == arch
assert atoms.info["mlip_model"] == model

# Check calculator results depend on invalidate_calc
if invalidate_calc:
Expand Down

0 comments on commit d8db0b0

Please sign in to comment.