allow custom cache based on XDG_CACHE_HOME env variable
alinelena committed Dec 17, 2024
1 parent c8f2d61 commit fa15322
Showing 3 changed files with 44 additions and 11 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -28,6 +28,7 @@
- [MACE-OFF: Transferable Organic Force Fields](#mace-off-transferable-organic-force-fields)
- [Example usage in ASE](#example-usage-in-ase-1)
- [Finetuning foundation models](#finetuning-foundation-models)
- [Caching](#caching)
- [Development](#development)
- [References](#references)
- [Contact](#contact)
@@ -59,7 +60,7 @@ A partial documentation is available at: https://mace-docs.readthedocs.io
**Make sure to install PyTorch.** Please refer to the [official PyTorch installation](https://pytorch.org/get-started/locally/) for the installation instructions. Select the appropriate options for your system.

### Installation from PyPI
This is the recommended way to install MACE.

```sh
pip install --upgrade pip
@@ -109,7 +110,7 @@ mace_run_train \

To give a specific validation set, use the argument `--valid_file`. To set a larger batch size for evaluating the validation set, specify `--valid_batch_size`.

To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. It is also possible to specify the model using the `--num_channels=128` and `--max_L=1` keys.

It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression.

@@ -295,11 +296,16 @@ mace_run_train \
--amsgrad \
--default_dtype="float32" \
--device=cuda \
--seed=3
```
Other options are "medium" and "large", or the path to a foundation model.
To finetune another model, pass its path with `--foundation_model=$path_model`. The model is loaded from that path, but you must also provide the full set of hyperparameters (hidden irreps, r_max, etc.) matching the model.

## Caching

By default, automatically downloaded models (such as those used by `mace_mp` and `mace_off`) and data for fine-tuning are stored in `~/.cache/mace`. To change this location, set the environment variable `XDG_CACHE_HOME`. When it is set, the cache path becomes `$XDG_CACHE_HOME/.cache/mace`.
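
The lookup can be sketched in a few lines of Python. This is a minimal illustration that mirrors the expression introduced in this commit. The helper name `resolve_cache_dir`, its `env` parameter, and the path `/scratch/alice` are hypothetical and not part of MACE.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_cache_dir(env: Optional[Mapping[str, str]] = None) -> Path:
    """Return the MACE cache directory for the given environment mapping.

    Hypothetical helper (not part of MACE); it mirrors the expression
    used in this commit.
    """
    env = os.environ if env is None else env
    # Fall back to the home directory when XDG_CACHE_HOME is unset.
    base = Path(env.get("XDG_CACHE_HOME", "~/")).expanduser()
    return base / ".cache/mace"


# Unset: resolves to .cache/mace under the current user's home directory.
print(resolve_cache_dir({}))
# Set: ".cache/mace" is still appended to the given directory.
print(resolve_cache_dir({"XDG_CACHE_HOME": "/scratch/alice"}))
```

Passing the environment as a mapping keeps the sketch easy to test without modifying `os.environ`.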

## Development

This project uses [pre-commit](https://pre-commit.com/) to execute code formatting and linting on commit.
36 changes: 30 additions & 6 deletions mace/calculators/foundations_models.py
@@ -42,11 +42,24 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:

checkpoint_url = (
urls.get(model, urls["medium"])
- if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2")
+ if model
+ in (
+     None,
+     "small",
+     "medium",
+     "large",
+     "small-0b",
+     "medium-0b",
+     "small-0b2",
+     "medium-0b2",
+     "large-0b2",
+ )
else model
)

- cache_dir = os.path.expanduser("~/.cache/mace")
+ cache_dir = (
+     Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace"
+ )
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
@@ -106,9 +119,17 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
- if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith(
-     "https:"
- ):
+ if model in (
+     None,
+     "small",
+     "medium",
+     "large",
+     "small-0b",
+     "medium-0b",
+     "small-0b2",
+     "medium-0b2",
+     "large-0b2",
+ ) or str(model).startswith("https:"):
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
else:
@@ -198,7 +219,10 @@ def mace_off(
if model in (None, "small", "medium", "large")
else model
)
- cache_dir = os.path.expanduser("~/.cache/mace")
+ cache_dir = (
+     Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser()
+     / ".cache/mace"
+ )
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
5 changes: 4 additions & 1 deletion mace/tools/multihead_tools.py
@@ -3,6 +3,7 @@
import logging
import os
import urllib.request
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
@@ -103,7 +104,9 @@ def assemble_mp_data(
try:
checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz"
descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy"
- cache_dir = os.path.expanduser("~/.cache/mace")
+ cache_dir = (
+     Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace"
+ )
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
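
The hunks in this commit derive the cached file name from the download URL in two ways. `download_mace_mp_checkpoint` and `assemble_mp_data` keep only alphanumeric characters and underscores from the URL's base name, while `mace_off` takes the base name and drops any query string. The sketch below reproduces both expressions. The function names and the second URL are illustrative assumptions, not MACE APIs.

```python
import os


def sanitized_name(url: str) -> str:
    """Keep only alphanumerics and underscores from the URL's base name."""
    return "".join(c for c in os.path.basename(url) if c.isalnum() or c in "_")


def name_without_query(url: str) -> str:
    """Take the URL's base name and drop any query string."""
    return os.path.basename(url).split("?")[0]


mp_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz"
off_url = "https://example.org/models/example_small.model?raw=true"  # hypothetical URL

print(sanitized_name(mp_url))       # mp_traj_combinedxyz
print(name_without_query(off_url))  # example_small.model
```

The first form also removes the dot before the file extension, so the cached file is stored as `mp_traj_combinedxyz` rather than `mp_traj_combined.xyz`.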