Skip to content

Commit

Permalink
allow custom cache based on XDG_CACHE_HOME env variable
Browse files Browse the repository at this point in the history
  • Loading branch information
alinelena committed Dec 17, 2024
1 parent c8f2d61 commit 86c075e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import urllib.request
from pathlib import Path
from typing import Union
from pathlib import Path

import torch
from ase import units
Expand Down Expand Up @@ -46,7 +47,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
else model
)

cache_dir = os.path.expanduser("~/.cache/mace")
cache_dir = Path(os.environ.get('XDG_CACHE_HOME',"~/")).expanduser() / Path(".cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
Expand Down Expand Up @@ -198,7 +199,7 @@ def mace_off(
if model in (None, "small", "medium", "large")
else model
)
cache_dir = os.path.expanduser("~/.cache/mace")
cache_dir = Path(os.environ.get('XDG_CACHE_HOME',"~/")).expanduser() / Path(".cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
Expand Down
3 changes: 2 additions & 1 deletion mace/tools/multihead_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import urllib.request
from typing import Any, Dict, List, Optional, Union
from pathlib import Path

import torch

Expand Down Expand Up @@ -103,7 +104,7 @@ def assemble_mp_data(
try:
checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz"
descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy"
cache_dir = os.path.expanduser("~/.cache/mace")
cache_dir = Path(os.environ.get('XDG_CACHE_HOME',"~/")).expanduser() / Path(".cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
Expand Down

0 comments on commit 86c075e

Please sign in to comment.