Merge branch 'main' into mace-tn
CheukHinHoJerry committed Oct 30, 2024
2 parents 7712975 + 74dcd4c commit 42b9728
Showing 16 changed files with 826 additions and 242 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -315,7 +315,7 @@ We are happy to accept pull requests under an [MIT license](https://choosealicen

If you use this code, please cite our papers:

```text
```bibtex
@inproceedings{Batatia2022mace,
title={{MACE}: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields},
author={Ilyes Batatia and David Peter Kovacs and Gregor N. C. Simm and Christoph Ortner and Gabor Csanyi},
157 changes: 88 additions & 69 deletions mace/calculators/foundations_models.py
@@ -15,6 +15,51 @@
)


def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"""
Downloads or locates the MACE-MP checkpoint file.
Args:
model (str, optional): Path to the model or size specification.
Defaults to None which uses the medium model.
Returns:
str: Path to the downloaded (or cached, if previously loaded) checkpoint file.
"""
if model in (None, "medium") and os.path.isfile(local_model_path):
return local_model_path

urls = {
"small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
"medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
"large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
}

checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)

cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"

if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
print(f"Downloading MACE model from {checkpoint_url!r}")
_, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path)
if "Content-Type: text/html" in http_msg:
raise RuntimeError(
f"Model download failed, please check the URL {checkpoint_url}"
)
print(f"Cached MACE model to {cached_model_path}")

return cached_model_path
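
Since this refactor factors checkpoint resolution out of `mace_mp`, a minimal usage sketch of the new helper may help; it assumes the function is importable from `mace.calculators.foundations_models`, as the diff suggests:

```python
from mace.calculators.foundations_models import download_mace_mp_checkpoint

# First call downloads the medium MACE-MP checkpoint into ~/.cache/mace;
# subsequent calls return the cached path without re-downloading.
ckpt_path = download_mace_mp_checkpoint("medium")

# A direct URL also works; the cached filename is the URL basename with
# everything except alphanumerics and underscores stripped.
ckpt_path = download_mace_mp_checkpoint(
    "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/"
    "2023-12-10-mace-128-L0_energy_epoch-249.model"
)
```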


def mace_mp(
model: Union[str, Path] = None,
device: str = "",
@@ -23,6 +68,7 @@ def mace_mp(
damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"]
dispersion_xc: str = "pbe",
dispersion_cutoff: float = 40.0 * units.Bohr,
return_raw_model: bool = False,
**kwargs,
) -> MACECalculator:
"""
@@ -42,59 +88,23 @@ def mace_mp(
model (str, optional): Path to the model. Defaults to None which first checks for
a local model and then downloads the default model from figshare. Specify "small",
"medium" or "large" to download a smaller or larger model from figshare.
device (str, optional): Device to use for the model. Defaults to "cuda".
device (str, optional): Device to use for the model. Defaults to "cuda" if available.
default_dtype (str, optional): Default dtype for the model. Defaults to "float32".
dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False.
damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
dispersion_cutoff (float, optional): Cutoff radius in Bhor for D3 dispersion corrections.
dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
**kwargs: Passed to MACECalculator and TorchDFTD3Calculator.
Returns:
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
if model in (None, "medium") and os.path.isfile(local_model_path):
model = local_model_path
print(
f"Using local medium Materials Project MACE model for MACECalculator {model}"
)
elif model in (None, "small", "medium", "large") or str(model).startswith("https:"):
try:
# checkpoints release: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0
urls = dict(
small="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", # 2023-12-10-mace-128-L0_energy_epoch-249.model
medium="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", # 2023-12-03-mace-128-L1_epoch-199.model
large="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", # MACE_MPtrj_2022.9.model
)
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
_, http_msg = urllib.request.urlretrieve(
checkpoint_url, cached_model_path
)
if "Content-Type: text/html" in http_msg:
raise RuntimeError(
f"Model download failed, please check the URL {checkpoint_url}"
)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using Materials Project MACE for MACECalculator with {model}"
print(msg)
except Exception as exc:
raise RuntimeError(
"Model download failed and no local model found"
) from exc
try:
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc

device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if default_dtype == "float64":
@@ -105,32 +115,36 @@ def mace_mp(
print(
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
)

if return_raw_model:
return torch.load(model_path, map_location=device)

mace_calc = MACECalculator(
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs
)
d3_calc = None
if dispersion:
gh_url = "https://github.com/pfnet-research/torch-dftd"
try:
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
except ImportError as exc:
raise RuntimeError(
f"Please install torch-dftd to use dispersion corrections (see {gh_url} from {exc})"
) from exc
print(
f"Using TorchDFTD3Calculator for D3 dispersion corrections (see {gh_url})"
)
dtype = torch.float32 if default_dtype == "float32" else torch.float64
d3_calc = TorchDFTD3Calculator(
device=device,
damping=damping,
dtype=dtype,
xc=dispersion_xc,
cutoff=dispersion_cutoff,
**kwargs,
)
calc = mace_calc if not dispersion else SumCalculator([mace_calc, d3_calc])
return calc

if not dispersion:
return mace_calc

try:
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
except ImportError as exc:
raise RuntimeError(
"Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)"
) from exc

print("Using TorchDFTD3Calculator for D3 dispersion corrections")
dtype = torch.float32 if default_dtype == "float32" else torch.float64
d3_calc = TorchDFTD3Calculator(
device=device,
damping=damping,
dtype=dtype,
xc=dispersion_xc,
cutoff=dispersion_cutoff,
**kwargs,
)

return SumCalculator([mace_calc, d3_calc])
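
A short sketch of how the reworked `mace_mp` entry point is meant to be called, covering both the new `return_raw_model` flag and the dispersion path (the `Atoms` setup is illustrative only):

```python
from ase.build import bulk
from mace.calculators import mace_mp

# Default path: an ASE calculator (a SumCalculator when dispersion=True,
# which requires torch-dftd to be installed).
calc = mace_mp(model="medium", dispersion=True, default_dtype="float64")
atoms = bulk("Cu", "fcc", a=3.6)
atoms.calc = calc
print(atoms.get_potential_energy())

# New in this commit: return the underlying torch.nn.Module instead of a
# calculator, e.g. for fine-tuning or inspection.
raw_model = mace_mp(model="medium", return_raw_model=True)
```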


def mace_off(
@@ -212,6 +226,7 @@ def mace_off(
def mace_anicc(
device: str = "cuda",
model_path: str = None,
return_raw_model: bool = False,
) -> MACECalculator:
"""
Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O).
@@ -227,4 +242,8 @@ def mace_off(
print(
"Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322"
)
return MACECalculator(model_path, device=device, default_dtype="float64")
if return_raw_model:
return torch.load(model_path, map_location=device)
return MACECalculator(
model_paths=model_path, device=device, default_dtype="float64"
)
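
`mace_anicc` gains the same escape hatch; a two-line sketch under the same assumptions:

```python
from mace.calculators import mace_anicc

calc = mace_anicc(device="cpu")                          # ASE calculator, as before
model = mace_anicc(device="cpu", return_raw_model=True)  # raw torch module, new here
```
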
97 changes: 67 additions & 30 deletions mace/calculators/mace.py
@@ -5,6 +5,7 @@
###########################################################################################


import logging
from glob import glob
from pathlib import Path
from typing import Union
@@ -18,7 +19,7 @@
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
from mace.tools.scripts_utils import extract_load
from mace.tools.scripts_utils import extract_model


def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
@@ -49,8 +50,9 @@ class MACECalculator(Calculator):

def __init__(
self,
model_paths: Union[list, str],
device: str,
model_paths: Union[list, str, None] = None,
models: Union[list[torch.nn.Module], torch.nn.Module, None] = None,
device: str = "cpu",
energy_units_to_eV: float = 1.0,
length_units_to_A: float = 1.0,
default_dtype="",
@@ -61,6 +63,24 @@ def __init__(
**kwargs,
):
Calculator.__init__(self, **kwargs)

if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
)
if model_paths is None:
logging.warning(f"{deprecation_message} in the future.")
model_paths = kwargs["model_path"]
else:
raise ValueError(
f"both 'model_path' and 'model_paths' given, {deprecation_message} only."
)

if (model_paths is None) == (models is None):
raise ValueError(
"Exactly one of 'model_paths' or 'models' must be provided"
)

self.results = {}

self.model_type = model_type
@@ -89,53 +109,70 @@ def __init__(
f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported"
)

if "model_path" in kwargs:
print("model_path argument deprecated, use model_paths")
model_paths = kwargs["model_path"]

if isinstance(model_paths, str):
# Find all models that satisfy the wildcard (e.g. mace_model_*.pt)
model_paths_glob = glob(model_paths)
if len(model_paths_glob) == 0:
raise ValueError(f"Couldn't find MACE model files: {model_paths}")
model_paths = model_paths_glob
elif isinstance(model_paths, Path):
model_paths = [model_paths]
if len(model_paths) == 0:
raise ValueError("No mace file names supplied")
self.num_models = len(model_paths)
if len(model_paths) > 1:
print(f"Running committee mace with {len(model_paths)} models")
if model_paths is not None:
if isinstance(model_paths, str):
# Find all models that satisfy the wildcard (e.g. mace_model_*.pt)
model_paths_glob = glob(model_paths)

if len(model_paths_glob) == 0:
raise ValueError(f"Couldn't find MACE model files: {model_paths}")

model_paths = model_paths_glob
elif isinstance(model_paths, Path):
model_paths = [model_paths]

if len(model_paths) == 0:
raise ValueError("No mace file names supplied")
self.num_models = len(model_paths)

# Load models from files
self.models = [
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]

elif models is not None:
if not isinstance(models, list):
models = [models]

if len(models) == 0:
raise ValueError("No models supplied")

self.models = models
self.num_models = len(models)

if self.num_models > 1:
print(f"Running committee mace with {self.num_models} models")

if model_type in ["MACE", "EnergyDipoleMACE"]:
self.implemented_properties.extend(
["energies", "energy_var", "forces_comm", "stress_var"]
)
elif model_type == "DipoleMACE":
self.implemented_properties.extend(["dipole_var"])

if compile_mode is not None:
print(f"Torch compile is enabled with mode: {compile_mode}")
self.models = [
torch.compile(
prepare(extract_load)(f=model_path, map_location=device),
prepare(extract_model)(model=model, map_location=device),
mode=compile_mode,
fullgraph=fullgraph,
)
for model_path in model_paths
for model in models
]
self.use_compile = True
else:
self.models = [
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]
self.use_compile = False

# Ensure all models are on the same device
for model in self.models:
model.to(device) # shouldn't be necessary but seems to help with GPU
model.to(device)

r_maxs = [model.r_max.cpu() for model in self.models]
r_maxs = np.array(r_maxs)
assert np.all(
r_maxs == r_maxs[0]
), "committee r_max are not all the same {' '.join(r_maxs)}"
if not np.all(r_maxs == r_maxs[0]):
    raise ValueError(f"committee r_max are not all the same: {' '.join(map(str, r_maxs))}")
self.r_max = float(r_maxs[0])

self.device = torch_tools.init_device(device)
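
The headline change in `mace.py` is that `MACECalculator` now accepts either checkpoint paths or already-loaded models. A sketch of the two construction styles (file names are hypothetical):

```python
import torch
from mace.calculators import MACECalculator

# Path-based construction, as before; a wildcard builds a committee.
calc = MACECalculator(model_paths="checkpoints/mace_model_*.pt", device="cpu")

# New in this commit: pass torch modules directly, skipping the torch.load
# round-trip inside the calculator.
models = [torch.load(p, map_location="cpu") for p in ["a.model", "b.model"]]
calc = MACECalculator(models=models, device="cpu")

# Supplying both (or neither) of model_paths/models raises ValueError, and
# the old model_path= keyword now emits a deprecation warning.
```
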
4 changes: 2 additions & 2 deletions mace/cli/active_learning_md.py
@@ -149,8 +149,8 @@ def run(args: argparse.Namespace) -> None:
atoms_index = args.config_index

mace_calc = MACECalculator(
mace_fname,
args.device,
model_paths=mace_fname,
device=args.device,
default_dtype=args.default_dtype,
)

12 more changed files not shown.
