Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add orb support... tricky #303

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class CorrelationKwargs(TypedDict, total=True):

# Janus specific
Architectures = Literal[
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet"
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "orb"
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"]
Expand Down
30 changes: 30 additions & 0 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,36 @@ def choose_calculator(
kwargs.setdefault("sevennet_config", None)
calculator = SevenNetCalculator(model=model_path, device=device, **kwargs)

elif arch == "orb":
from orb_models import __version__
from orb_models.forcefield.calculator import ORBCalculator
from orb_models.forcefield.graph_regressor import GraphRegressor
import orb_models.forcefield.pretrained as orb_ff

if isinstance(model_path, str):
match model_path:
ElliottKasoar marked this conversation as resolved.
Show resolved Hide resolved
case "orb-v1":
model = orb_ff.orb_v1()
case "orb-mptraj-only-v1":
model = orb_ff.orb_v1_mptraj_only()
case "orb-d3-v1":
model = orb_ff.orb_d3_v1()
case "orb-d3-xs-v1":
model = orb_ff.orb_d3_xs_v1()
case "orb-d3-sm-v1":
model = orb_ff.orb_d3_sm_v1()
case _:
raise ValueError(
"Please specify `model_path`, as there is no "
f"default model for {arch}"
)
Comment on lines +235 to +250
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be split into multiple lines, but this is the gist

Suggested change
match model_path:
case "orb-v1":
model = orb_ff.orb_v1()
case "orb-mptraj-only-v1":
model = orb_ff.orb_v1_mptraj_only()
case "orb-d3-v1":
model = orb_ff.orb_d3_v1()
case "orb-d3-xs-v1":
model = orb_ff.orb_d3_xs_v1()
case "orb-d3-sm-v1":
model = orb_ff.orb_d3_sm_v1()
case _:
raise ValueError(
"Please specify `model_path`, as there is no "
f"default model for {arch}"
)
model = getattr(orb_ff, model_path.sub("-", "_"), None)()
if model is None:
raise ValueError(
"Please specify `model_path`, as there is no "
f"default model for {arch}"
)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice but orb-mptraj does not match the pattern.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we choose what "model_path" is? You could also easily special case that. Here, it's not clear that that is different at first glance.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no really I user the canonical names from their website... which i suspect is what people will expect to call them

elif isinstance(model_path, GraphRegressor):
model = model_path
else:
model = orb_ff.orb_v1_mptraj_only()

calculator = ORBCalculator(model=model, device=device, **kwargs)

else:
raise ValueError(
f"Unrecognized {arch=}. Suported architectures "
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,18 @@ m3gnet = [
"matgl == 1.1.3",
"dgl == 2.1.0",
]
orb = [
"orb-models == 0.4.1",
"pynanoflann",
]
sevennet = [
"sevenn == 0.10.0",
]
all = [
"janus-core[alignn]",
"janus-core[chgnet]",
"janus-core[m3gnet]",
"janus-core[orb]",
"janus-core[sevennet]",
]

Expand Down Expand Up @@ -164,3 +169,6 @@ default-groups = [
"docs",
"pre-commit",
]

[tool.uv.sources]
pynanoflann = { git = "https://github.com/dwastberg/pynanoflann", rev = "af434039ae14bedcbb838a7808924d6689274168" }
Loading