Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into cleanup-zbl
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Dec 18, 2024
2 parents 4d8045e + c8f2d61 commit fc6fa95
Show file tree
Hide file tree
Showing 43 changed files with 2,877 additions and 417 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ dist/
*.xyz
/checkpoints
*.model

.benchmarks
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ repos:
'--disable=cell-var-from-loop',
'--disable=duplicate-code',
'--disable=use-dict-literal',
'--max-module-lines=1500',
]
exclude: *exclude_files
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
[![License](https://img.shields.io/badge/License-MIT%202.0-blue.svg)](https://opensource.org/licenses/mit)
[![GitHub issues](https://img.shields.io/github/issues/ACEsuit/mace.svg)](https://GitHub.com/ACEsuit/mace/issues/)
[![Documentation Status](https://readthedocs.org/projects/mace/badge/)](https://mace-docs.readthedocs.io/en/latest/)
[![DOI](https://zenodo.org/badge/505964914.svg)](https://doi.org/10.5281/zenodo.14103332)

## Table of contents

Expand All @@ -19,6 +20,7 @@
- [Training](#training)
- [Evaluation](#evaluation)
- [Tutorials](#tutorials)
- [CUDA acceleration with cuEquivariance](#cuda-acceleration-with-cuequivariance)
- [Weights and Biases for experiment tracking](#weights-and-biases-for-experiment-tracking)
- [Pretrained Foundation Models](#pretrained-foundation-models)
- [MACE-MP: Materials Project Force Fields](#mace-mp-materials-project-force-fields)
Expand Down Expand Up @@ -170,6 +172,9 @@ We also have a more detailed Colab tutorials on:
- [Introduction to MACE active learning and fine-tuning](https://colab.research.google.com/drive/1oCSVfMhWrqHTeHbKgUSQN9hTKxLzoNyb)
- [MACE theory and code (advanced)](https://colab.research.google.com/drive/1AlfjQETV_jZ0JQnV5M3FGwAM2SGCl2aU)

## CUDA acceleration with cuEquivariance

MACE supports CUDA acceleration with the cuEquivariance library. To install the library and use the acceleration, see our documentation at https://mace-docs.readthedocs.io/en/latest/guide/cuda_acceleration.html.

## On-line data loading for large datasets

Expand Down
2 changes: 1 addition & 1 deletion mace/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.3.7"
__version__ = "0.3.9"

__all__ = ["__version__"]
78 changes: 48 additions & 30 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,16 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
"medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
"large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
"small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model",
"medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model",
"small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
"medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
"large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
}

checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2")
else model
)

Expand Down Expand Up @@ -101,8 +106,15 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith(
"https:"
):
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
model_path = model
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc

Expand Down Expand Up @@ -173,36 +185,42 @@ def mace_off(
MACECalculator: trained on the MACE-OFF23 dataset
"""
try:
urls = dict(
small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
)
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
print(
"The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
if model in (None, "small", "medium", "large") or str(model).startswith(
"https:"
):
urls = dict(
small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
)
print(
"ASL is based on the Gnu Public License, but does not permit commercial use"
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
urllib.request.urlretrieve(checkpoint_url, cached_model_path)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
print(msg)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
print(
"The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
)
print(
"ASL is based on the Gnu Public License, but does not permit commercial use"
)
urllib.request.urlretrieve(checkpoint_url, cached_model_path)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
print(msg)
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
except Exception as exc:
raise RuntimeError("Model download failed") from exc
raise RuntimeError("Model download failed and no local model found") from exc

device = device or ("cuda" if torch.cuda.is_available() else "cpu")

Expand Down
37 changes: 29 additions & 8 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3

from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
Expand Down Expand Up @@ -60,10 +62,13 @@ def __init__(
model_type="MACE",
compile_mode=None,
fullgraph=True,
enable_cueq=False,
**kwargs,
):
Calculator.__init__(self, **kwargs)

if enable_cueq:
assert model_type == "MACE", "CuEq only supports MACE models"
compile_mode = None
if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
Expand Down Expand Up @@ -130,6 +135,12 @@ def __init__(
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]
if enable_cueq:
print("Converting models to CuEq for acceleration")
self.models = [
run_e3nn_to_cueq(model, device=device).to(device)
for model in self.models
]

elif models is not None:
if not isinstance(models, list):
Expand Down Expand Up @@ -159,7 +170,7 @@ def __init__(
mode=compile_mode,
fullgraph=fullgraph,
)
for model in models
for model in self.models
]
self.use_compile = True
else:
Expand Down Expand Up @@ -390,24 +401,34 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
num_interactions = int(self.models[0].num_interactions)
if num_layers == -1:
num_layers = int(self.models[0].num_interactions)
num_layers = num_interactions
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]

irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = (
num_invariant_features # Equivariant features not created for the last layer
)

if invariants_only:
irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
l_max = irreps_out.lmax
num_features = irreps_out.dim // (l_max + 1) ** 2
descriptors = [
extract_invariant(
descriptor,
num_layers=num_layers,
num_features=num_features,
num_features=num_invariant_features,
l_max=l_max,
)
for descriptor in descriptors
]
descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors]
to_keep = np.sum(per_layer_features[:num_layers])
descriptors = [
descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
]

if self.num_models == 1:
return descriptors[0]
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/active_learning_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--config", help="path to XYZ configurations", required=True)
parser.add_argument(
"--config_index", help="index of configuration", type=int, default=-1
Expand Down
Loading

0 comments on commit fc6fa95

Please sign in to comment.