diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index 073f9223..0a4241ef 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -138,7 +138,7 @@ class CorrelationKwargs(TypedDict, total=True): # Janus specific Architectures = Literal[ - "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet" + "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "nequip" ] Devices = Literal["cpu", "cuda", "mps", "xpu"] Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"] diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index dd6c0354..648b0b34 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -217,6 +217,15 @@ def choose_calculator( kwargs.setdefault("sevennet_config", None) calculator = SevenNetCalculator(model=model, device=device, **kwargs) + elif arch == "nequip": + from nequip.ase import NequIPCalculator + + model = model_path if model_path else "" + + calculator = NequIPCalculator.from_deployed_model( + model_path=model, device=device, **kwargs + ) + else: raise ValueError( f"Unrecognized {arch=}. Suported architectures " diff --git a/pyproject.toml b/pyproject.toml index d344e5b4..fb2a4fc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ sevenn = { version = "0.9.3", optional = true } torchdata = {version = "0.7.1", optional = true} # Pin due to dgl issue torch_geometric = { version = "^2.5.3", optional = true } ruff = "^0.5.7" +nequip = {version = "^0.6.1", optional = true } [tool.poetry.extras] all = ["alignn", "chgnet", "matgl", "dgl", "torchdata", "sevenn", "torch_geometric"] @@ -55,6 +56,7 @@ alignn = ["alignn"] chgnet = ["chgnet"] m3gnet = ["matgl", "dgl", "torchdata"] sevennet = ["sevenn", "torch_geometric"] +nequip = ["nequip"] [tool.poetry.group.dev.dependencies] coverage = {extras = ["toml"], version = "^7.4.1"}