Commit

Merge branch 'ACEsuit:main' into lbfgs-optimizer

ttompa authored Nov 17, 2024
2 parents 810ac05 + bd41231 commit 98fb1c3
Showing 26 changed files with 1,123 additions and 308 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -5,6 +5,7 @@
[![License](https://img.shields.io/badge/License-MIT%202.0-blue.svg)](https://opensource.org/licenses/mit)
[![GitHub issues](https://img.shields.io/github/issues/ACEsuit/mace.svg)](https://GitHub.com/ACEsuit/mace/issues/)
[![Documentation Status](https://readthedocs.org/projects/mace/badge/)](https://mace-docs.readthedocs.io/en/latest/)
[![DOI](https://zenodo.org/badge/505964914.svg)](https://doi.org/10.5281/zenodo.14103332)

## Table of contents

2 changes: 1 addition & 1 deletion mace/__version__.py
@@ -1,3 +1,3 @@
__version__ = "0.3.7"
__version__ = "0.3.8"

__all__ = ["__version__"]
71 changes: 42 additions & 29 deletions mace/calculators/foundations_models.py
@@ -101,8 +101,15 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
if model in (None, "small", "medium", "large") or str(model).startswith(
"https:"
):
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
model_path = model
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc

@@ -173,36 +180,42 @@ def mace_off(
MACECalculator: trained on the MACE-OFF23 dataset
"""
try:
urls = dict(
small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
)
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
print(
"The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
if model in (None, "small", "medium", "large") or str(model).startswith(
"https:"
):
urls = dict(
small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
)
print(
"ASL is based on the Gnu Public License, but does not permit commercial use"
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
urllib.request.urlretrieve(checkpoint_url, cached_model_path)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
print(msg)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
print(
"The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
)
print(
"ASL is based on the Gnu Public License, but does not permit commercial use"
)
urllib.request.urlretrieve(checkpoint_url, cached_model_path)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
print(msg)
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
except Exception as exc:
raise RuntimeError("Model download failed") from exc
raise RuntimeError("Model download failed and no local model found") from exc

device = device or ("cuda" if torch.cuda.is_available() else "cpu")

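With this change mace_mp (and mace_off) also accepts a path to a checkpoint that already exists on disk, in addition to the named sizes and https: URLs. A rough usage sketch, with an illustrative local path:

from mace.calculators import mace_mp

# Named size or https: URL -> downloaded and cached as before
calc = mace_mp(model="medium", device="cpu")

# Anything else is now treated as a local file; a missing file raises
# FileNotFoundError instead of being handed to the downloader.
calc = mace_mp(model="checkpoints/my_finetuned.model", device="cpu")  # illustrative path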
2 changes: 1 addition & 1 deletion mace/calculators/mace.py
@@ -159,7 +159,7 @@ def __init__(
mode=compile_mode,
fullgraph=fullgraph,
)
for model in models
for model in self.models
]
self.use_compile = True
else:
4 changes: 3 additions & 1 deletion mace/cli/active_learning_md.py
@@ -14,7 +14,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--config", help="path to XYZ configurations", required=True)
parser.add_argument(
"--config_index", help="index of configuration", type=int, default=-1
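This and several other CLI parsers in the commit switch to argparse.ArgumentDefaultsHelpFormatter, which appends each argument's default value to its --help text. A minimal standalone illustration (not part of the commit):

import argparse

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--config_index", help="index of configuration", type=int, default=-1
)
parser.print_help()  # the option is now listed as: index of configuration (default: -1)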
31 changes: 31 additions & 0 deletions mace/cli/convert_device.py
@@ -0,0 +1,31 @@
from argparse import ArgumentParser

import torch


def main():
parser = ArgumentParser()
parser.add_argument(
"--target_device",
"-t",
help="device to convert to, usually 'cpu' or 'cuda'",
default="cpu",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model.to(args.target_device)
torch.save(model, args.output_file)


if __name__ == "__main__":
main()
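The new convert_device script simply reloads a serialized model, moves it to the requested device, and saves it next to the original. A minimal sketch of the equivalent in plain Python (the file name is illustrative, and the module-style invocation in the comment is an assumption):

import torch

# roughly: python -m mace.cli.convert_device my_mace.model -t cuda
model = torch.load("my_mace.model")  # illustrative file name
model.to("cuda")
torch.save(model, "my_mace.model.cuda")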
4 changes: 3 additions & 1 deletion mace/cli/create_lammps_model.py
@@ -7,7 +7,9 @@


def parse_args():
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"model_path",
type=str,
21 changes: 19 additions & 2 deletions mace/cli/eval_configs.py
@@ -16,7 +16,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--configs", help="path to XYZ configurations", required=True)
parser.add_argument("--model", help="path to model", required=True)
parser.add_argument("--output", help="output path", required=True)
@@ -53,6 +55,13 @@ def parse_args() -> argparse.Namespace:
type=str,
default="MACE_",
)
parser.add_argument(
"--head",
help="Model head used for evaluation",
type=str,
required=False,
default=None,
)
return parser.parse_args()


@@ -76,14 +85,22 @@ def run(args: argparse.Namespace) -> None:

# Load data and prepare input
atoms_list = ase.io.read(args.configs, index=":")
if args.head is not None:
for atoms in atoms_list:
atoms.info["head"] = args.head
configs = [data.config_from_atoms(atoms) for atoms in atoms_list]

z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])

try:
heads = model.heads
except AttributeError:
heads = None

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=z_table, cutoff=float(model.r_max)
config, z_table=z_table, cutoff=float(model.r_max), heads=heads
)
for config in configs
],
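The new --head option tags every configuration with the chosen model head before the data loader is built. Outside the script this is roughly equivalent to the following (file path and head name are illustrative):

import ase.io

atoms_list = ase.io.read("configs.xyz", index=":")  # illustrative path
for atoms in atoms_list:
    atoms.info["head"] = "dft_head"  # what --head dft_head would set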
4 changes: 3 additions & 1 deletion mace/cli/fine_tuning_select.py
@@ -20,7 +20,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--configs_pt",
help="path to XYZ configurations for the pretraining",
5 changes: 4 additions & 1 deletion mace/cli/plot_train.py
@@ -60,7 +60,10 @@ def parse_training_results(path: str) -> List[dict]:


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Plot mace training statistics")
parser = argparse.ArgumentParser(
description="Plot mace training statistics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--path", help="path to results file or directory", required=True
)
46 changes: 37 additions & 9 deletions mace/cli/run_train.py
@@ -36,7 +36,6 @@
LRScheduler,
check_path_ase_read,
convert_to_json_format,
create_error_table,
dict_to_array,
extract_config_mace_model,
get_atomic_energies,
@@ -49,9 +48,11 @@
get_params_options,
get_swa,
print_git_commit,
remove_pt_head,
setup_wandb,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable


@@ -115,10 +116,6 @@ def run(args: argparse.Namespace) -> None:
commit = print_git_commit()
model_foundation: Optional[torch.nn.Module] = None
if args.foundation_model is not None:
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
if args.foundation_model in ["small", "medium", "large"]:
logging.info(
f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint."
@@ -148,6 +145,27 @@ def run(args: argparse.Namespace) -> None:
f"Using foundation model {args.foundation_model} as initial checkpoint."
)
args.r_max = model_foundation.r_max.item()
if (
args.foundation_model not in ["small", "medium", "large"]
and args.pt_train_file is None
):
logging.warning(
"Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file."
)
args.multiheads_finetuning = False
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
# check that the foundation model has a single head, if not, use the first head
if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1:
logging.warning(
"Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head."
)
model_foundation = remove_pt_head(
model_foundation, args.foundation_head
)
else:
args.multiheads_finetuning = False

@@ -353,8 +371,14 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict[head_config.head_name] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
@@ -372,8 +396,14 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict["pt_head"] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
@@ -575,7 +605,6 @@ def run(args: argparse.Namespace) -> None:
distributed_model = DDP(model, device_ids=[local_rank])
else:
distributed_model = None

tools.train(
model=model,
loss_fn=loss_fn,
@@ -654,7 +683,6 @@ def run(args: argparse.Namespace) -> None:
folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
)
for test_name, test_set in test_sets.items():
print(test_name)
test_sampler = None
if args.distributed:
test_sampler = torch.utils.data.distributed.DistributedSampler(
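The E0 extraction above now copes with foundation models whose atomic_energies tensor carries one row per head, falling back to the first (pretraining) head. A small standalone illustration of the reshaping, with made-up shapes and values:

import torch

atomic_energies = torch.zeros(3, 89)  # e.g. 3 heads x 89 elements
if atomic_energies.ndim > 1:
    atomic_energies = atomic_energies.squeeze()  # drops size-1 dims only
if atomic_energies.ndim == 2:
    atomic_energies = atomic_energies[0]  # still per-head: keep the first head's E0s
print(atomic_energies.shape)  # torch.Size([89])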
33 changes: 33 additions & 0 deletions mace/cli/select_head.py
@@ -0,0 +1,33 @@
from argparse import ArgumentParser

import torch

from mace.tools.scripts_utils import remove_pt_head


def main():
parser = ArgumentParser()
parser.add_argument(
"--head_name",
"-n",
help="name of the head to extract",
default=None,
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.head_name

model = torch.load(args.model_file)
model_single = remove_pt_head(model, args.head_name)
torch.save(model_single, args.output_file)


if __name__ == "__main__":
main()
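A rough sketch of doing the same thing directly with the helper the script wraps (file and head names are illustrative):

import torch

from mace.tools.scripts_utils import remove_pt_head

model = torch.load("multihead.model")  # illustrative path
single = remove_pt_head(model, "my_dataset_head")  # extract the named head, per the script's --head_name help
torch.save(single, "multihead.model.my_dataset_head")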
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
@@ -15,6 +15,8 @@
NonLinearReadoutBlock,
RadialEmbeddingBlock,
RealAgnosticAttResidualInteractionBlock,
RealAgnosticDensityInteractionBlock,
RealAgnosticDensityResidualInteractionBlock,
RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock,
ResidualElementDependentInteractionBlock,
@@ -56,6 +58,8 @@
"RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
"RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
"RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
"RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
}

scaling_classes: Dict[str, Callable] = {
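The two new density interaction blocks are registered alongside the existing ones, so they can be selected by name in a model specification. A minimal lookup sketch (the registry's variable name is not visible in this excerpt and is assumed to be interaction_classes):

from mace import modules

block_cls = modules.interaction_classes["RealAgnosticDensityResidualInteractionBlock"]  # assumed dict name
print(block_cls.__name__)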
(remaining changed files not loaded)
