Merge pull request #668 from ACEsuit/develop
Bug fix and density normalization
ilyes319 authored Nov 4, 2024
2 parents 74dcd4c + 29c99ae commit 4081abd
Showing 16 changed files with 446 additions and 51 deletions.
71 changes: 42 additions & 29 deletions mace/calculators/foundations_models.py
@@ -101,8 +101,15 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
if model in (None, "small", "medium", "large") or str(model).startswith(
"https:"
):
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
model_path = model
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc

@@ -173,36 +180,42 @@ def mace_off(
MACECalculator: trained on the MACE-OFF23 dataset
"""
try:
urls = dict(
small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
)
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
print(
"The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
if model in (None, "small", "medium", "large") or str(model).startswith(
"https:"
):
urls = dict(
small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
)
print(
"ASL is based on the Gnu Public License, but does not permit commercial use"
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
urllib.request.urlretrieve(checkpoint_url, cached_model_path)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
print(msg)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
print(
"The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
)
print(
"ASL is based on the Gnu Public License, but does not permit commercial use"
)
urllib.request.urlretrieve(checkpoint_url, cached_model_path)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
print(msg)
else:
if not Path(model).exists():
raise FileNotFoundError(f"{model} not found locally")
except Exception as exc:
raise RuntimeError("Model download failed") from exc
raise RuntimeError("Model download failed and no local model found") from exc

device = device or ("cuda" if torch.cuda.is_available() else "cpu")

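For context, a minimal usage sketch of the behaviour this file now supports: besides the named sizes and URLs, an existing local checkpoint path is accepted directly, and a missing path raises an error instead of falling back to a download. The local file name below is hypothetical.

from mace.calculators import mace_mp, mace_off

# Named size (or URL): the checkpoint is downloaded and cached as before.
calc = mace_mp(model="medium", device="cpu")

# New in this change: a path to an existing local checkpoint is used as-is;
# a non-existent path raises FileNotFoundError, surfaced as RuntimeError.
calc_local = mace_mp(model="checkpoints/my_finetuned.model", device="cpu")  # hypothetical path

# mace_off follows the same branching for local paths.
calc_off = mace_off(model="small", device="cpu")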
2 changes: 1 addition & 1 deletion mace/calculators/mace.py
@@ -159,7 +159,7 @@ def __init__(
mode=compile_mode,
fullgraph=fullgraph,
)
for model in models
for model in self.models
]
self.use_compile = True
else:
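A minimal sketch of the pattern being fixed in this hunk, with illustrative names: torch.compile should wrap the models stored on the calculator (self.models), which may already have been prepared for the target device, rather than the raw constructor argument.

import torch

class TinyCalculator:
    def __init__(self, models, device="cpu", compile_mode=None):
        # Store the prepared models first ...
        self.models = [m.to(device).eval() for m in models]
        if compile_mode is not None:
            # ... then compile the stored copies; compiling the local
            # `models` argument (the bug fixed here) would leave
            # self.models untouched by torch.compile.
            self.models = [
                torch.compile(m, mode=compile_mode) for m in self.models
            ]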
31 changes: 31 additions & 0 deletions mace/cli/convert_dev.py
@@ -0,0 +1,31 @@
from argparse import ArgumentParser

import torch


def main():
parser = ArgumentParser()
parser.add_argument(
"--target_device",
"-t",
help="device to convert to, usually 'cpu' or 'cuda'",
default="cpu",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model.to(args.target_device)
torch.save(model, args.output_file)


if __name__ == "__main__":
main()
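A hedged usage sketch of the new converter, assuming the checkpoint was produced by torch.save on a full model object (which is what torch.load here expects); file names are hypothetical.

# Command-line use (module path as per this diff):
#   python -m mace.cli.convert_dev my_mace.model -t cuda
#   -> writes my_mace.model.cuda next to the input by default.

# Inline equivalent of what the script does:
import torch

model = torch.load("my_mace.model")      # hypothetical checkpoint file
model.to("cuda")                         # move parameters and buffers
torch.save(model, "my_mace.model.cuda")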
17 changes: 16 additions & 1 deletion mace/cli/eval_configs.py
@@ -53,6 +53,13 @@ def parse_args() -> argparse.Namespace:
type=str,
default="MACE_",
)
parser.add_argument(
"--head",
help="Model head used for evaluation",
type=str,
required=False,
default=None,
)
return parser.parse_args()


@@ -76,14 +83,22 @@ def run(args: argparse.Namespace) -> None:

# Load data and prepare input
atoms_list = ase.io.read(args.configs, index=":")
if args.head is not None:
for atoms in atoms_list:
atoms.info["head"] = args.head
configs = [data.config_from_atoms(atoms) for atoms in atoms_list]

z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])

try:
heads = model.heads
except AttributeError:
heads = None

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=z_table, cutoff=float(model.r_max)
config, z_table=z_table, cutoff=float(model.r_max), heads=heads
)
for config in configs
],
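A hedged sketch of how the new --head option would be used when evaluating with a multihead model; the head name, file names, and the console entry point are assumptions, and the other flags follow the script's existing options (not all visible in this hunk). Internally the flag simply tags each configuration via atoms.info["head"] before batching, and the model's heads (if present) are forwarded to AtomicData.from_config.

# Command-line sketch (entry point name assumed from the repository's scripts):
#   mace_eval_configs --configs=test.xyz --model=multihead.model \
#       --output=evaluated.xyz --head=my_head

# What --head does per configuration (names illustrative):
import ase.io

atoms_list = ase.io.read("test.xyz", index=":")
for atoms in atoms_list:
    atoms.info["head"] = "my_head"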
17 changes: 14 additions & 3 deletions mace/cli/run_train.py
@@ -353,8 +353,14 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict[head_config.head_name] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
Expand All @@ -372,8 +378,14 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict["pt_head"] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
@@ -653,7 +665,6 @@ def run(args: argparse.Namespace) -> None:
folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
)
for test_name, test_set in test_sets.items():
print(test_name)
test_sampler = None
if args.distributed:
test_sampler = torch.utils.data.distributed.DistributedSampler(
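The added E0 handling can be read as follows (a sketch with illustrative shapes): a single-head foundation model may store atomic energies as (1, n_elements), which squeeze() flattens, while a genuine multihead model keeps shape (n_heads, n_elements) after squeeze(), in which case the first head's row is taken as the reference E0s.

import torch

# Illustrative shapes only; 89 elements and 2 heads are arbitrary.
single_head = torch.randn(1, 89)
multi_head = torch.randn(2, 89)

def select_e0s(atomic_energies: torch.Tensor) -> torch.Tensor:
    if atomic_energies.ndim > 1:
        atomic_energies = atomic_energies.squeeze()  # drops a size-1 head dim
        if atomic_energies.ndim == 2:
            atomic_energies = atomic_energies[0]     # first head as reference
    return atomic_energies

assert select_e0s(single_head).shape == (89,)
assert select_e0s(multi_head).shape == (89,)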
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
@@ -15,6 +15,8 @@
NonLinearReadoutBlock,
RadialEmbeddingBlock,
RealAgnosticAttResidualInteractionBlock,
RealAgnosticDensityInteractionBlock,
RealAgnosticDensityResidualInteractionBlock,
RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock,
ResidualElementDependentInteractionBlock,
@@ -56,6 +58,8 @@
"RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
"RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
"RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
"RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
}

scaling_classes: Dict[str, Callable] = {
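For reference, a hedged sketch of how the newly registered density-normalized blocks would be selected by name. The dictionary shown in this hunk is assumed to be the module's interaction_classes registry, and the training flag mentioned in the comment is an assumption about the wider CLI, not part of this diff.

from mace import modules

# Resolve a block class from its registered string name, as a training
# configuration would (e.g. via an --interaction=... style option).
block_cls = modules.interaction_classes["RealAgnosticDensityResidualInteractionBlock"]
print(block_cls.__name__)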