Skip to content

Commit

Permalink
Merge pull request #709 from ACEsuit/develop
Browse files Browse the repository at this point in the history
Add cuequivariance support
  • Loading branch information
ilyes319 authored Nov 22, 2024
2 parents bd41231 + 28dba59 commit 67ec3b3
Show file tree
Hide file tree
Showing 20 changed files with 1,418 additions and 105 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ repos:
'--disable=cell-var-from-loop',
'--disable=duplicate-code',
'--disable=use-dict-literal',
'--max-module-lines=1500',
]
exclude: *exclude_files
5 changes: 5 additions & 0 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
"medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
"large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
"small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model",
"medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model",
"small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
"medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
"large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
}

checkpoint_url = (
Expand Down
12 changes: 11 additions & 1 deletion mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ase.stress import full_3x3_to_voigt_6_stress

from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
Expand Down Expand Up @@ -60,10 +61,13 @@ def __init__(
model_type="MACE",
compile_mode=None,
fullgraph=True,
enable_cueq=False,
**kwargs,
):
Calculator.__init__(self, **kwargs)

if enable_cueq:
assert model_type == "MACE", "CuEq only supports MACE models"
compile_mode = None
if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
Expand Down Expand Up @@ -130,6 +134,12 @@ def __init__(
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]
if enable_cueq:
print("Converting models to CuEq for acceleration")
self.models = [
run_e3nn_to_cueq(model, device=device).to(device)
for model in self.models
]

elif models is not None:
if not isinstance(models, list):
Expand Down
193 changes: 193 additions & 0 deletions mace/cli/convert_cueq_e3nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import argparse
import logging
import os
from typing import Dict, List, Tuple

import torch

from mace.tools.scripts_utils import extract_config_mace_model


def get_transfer_keys() -> List[str]:
    """Return the state-dict keys that are copied verbatim between models.

    The list is hard-coded: the embedding, atomic-energy, readout and
    scale/shift entries come first, followed by the per-layer interaction
    and product entries for layers 0 and 1 (four ``conv_tp_weights``
    layers each).

    Returns:
        The key names, in a fixed order.
    """
    keys = [
        "node_embedding.linear.weight",
        "radial_embedding.bessel_fn.bessel_weights",
        "atomic_energies_fn.atomic_energies",
        "readouts.0.linear.weight",
        "scale_shift.scale",
        "scale_shift.shift",
    ]
    # Second readout: linear_1 and linear_2.
    for i in range(1, 3):
        keys.append(f"readouts.1.linear_{i}.weight")

    # Per-layer entries for the two interaction/product layers.
    for j in range(2):
        keys.append(f"interactions.{j}.linear_up.weight")
        for i in range(4):
            keys.append(f"interactions.{j}.conv_tp_weights.layer{i}.weight")
        keys.append(f"interactions.{j}.linear.weight")
        keys.append(f"interactions.{j}.skip_tp.weight")
        keys.append(f"products.{j}.linear.weight")
    return keys


def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]:
    """Determine the ``(product index, kmax)`` pairs for a model.

    Args:
        max_L: Value used as ``kmax`` for the first product layer.
        correlation: Correlation order of the model; only ``3`` is supported.

    Returns:
        ``[(0, max_L), (1, 0)]`` -- one tuple per product layer, matching the
        declared return type (the pairs were previously returned as lists).

    Raises:
        NotImplementedError: If ``correlation`` is anything other than 3.
    """
    if correlation == 2:
        raise NotImplementedError("Correlation 2 not supported yet")
    if correlation != 3:
        raise NotImplementedError(f"Correlation {correlation} not supported")
    return [(0, max_L), (1, 0)]


def transfer_symmetric_contractions(
    source_dict: Dict[str, torch.Tensor],
    target_dict: Dict[str, torch.Tensor],
    max_L: int,
    correlation: int,
):
    """Split the combined CuEq contraction weights into the E3nn entries.

    For every ``(i, kmax)`` pair from ``get_kmax_pairs`` the source tensor
    ``products.{i}.symmetric_contractions.weight`` is split along dim 1.
    The chunk sizes are read from dim 1 of the existing target tensors
    (``weights_max``, ``weights.0``, ``weights.1`` for each ``k`` in
    ``0..kmax``), and each chunk then replaces the corresponding entry.

    Args:
        source_dict: State dict holding the combined weights.
        target_dict: State dict to update; modified in place.
        max_L: Forwarded to ``get_kmax_pairs``.
        correlation: Forwarded to ``get_kmax_pairs``.
    """
    for i, kmax in get_kmax_pairs(max_L, correlation):
        # Combined weight tensor for this product layer.
        wm = source_dict[f"products.{i}.symmetric_contractions.weight"]

        # Target keys in the order their chunks appear in the combined tensor.
        ordered_keys = [
            f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}"
            for k in range(kmax + 1)
            for suffix in ["_max", ".0", ".1"]
        ]

        # Chunk sizes come from the shapes the target currently holds.
        chunk_sizes = [target_dict[name].shape[1] for name in ordered_keys]
        chunks = torch.split(wm, chunk_sizes, dim=1)

        for name, chunk in zip(ordered_keys, chunks):
            target_dict[name] = chunk


def transfer_weights(
    source_model: torch.nn.Module,
    target_model: torch.nn.Module,
    max_L: int,
    correlation: int,
):
    """Copy the weights of a CuEq model into an E3nn model.

    Steps, in order:
      1. copy every key from ``get_transfer_keys`` that exists in the source
         (missing keys are logged as warnings);
      2. split and copy the symmetric-contraction weights;
      3. copy any other key present in both state dicts whose shapes match
         (mismatches are logged as warnings);
      4. copy ``avg_num_neighbors`` for interactions 0 and 1;
      5. load the assembled state dict into ``target_model``.

    Args:
        source_model: Model to read weights from.
        target_model: Model to write weights into; modified in place.
        max_L: Forwarded to ``transfer_symmetric_contractions``.
        correlation: Forwarded to ``transfer_symmetric_contractions``.
    """
    source_dict = source_model.state_dict()
    target_dict = target_model.state_dict()

    # 1. Keys that map one-to-one.
    transfer_keys = get_transfer_keys()
    for key in transfer_keys:
        if key not in source_dict:
            logging.warning(f"Key {key} not found in source model")
            continue
        target_dict[key] = source_dict[key]

    # 2. Symmetric contractions need their combined tensor split up.
    transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)

    # 3. Everything else both models share, minus what was already handled.
    shared_keys = (set(source_dict.keys()) & set(target_dict.keys())) - set(
        transfer_keys
    )
    for key in shared_keys:
        if "symmetric_contraction" in key:
            continue
        if source_dict[key].shape != target_dict[key].shape:
            logging.warning(
                f"Shape mismatch for key {key}: "
                f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
            )
            continue
        logging.debug(f"Transferring additional key: {key}")
        target_dict[key] = source_dict[key]

    # 4. avg_num_neighbors is a plain attribute, so it is copied by hand.
    for i in (0, 1):
        neighbors = source_model.interactions[i].avg_num_neighbors
        target_model.interactions[i].avg_num_neighbors = neighbors

    # 5. Push the assembled weights into the target model.
    target_model.load_state_dict(target_dict)


def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True):
    """Convert a CuEq model into an equivalent model without CuEq settings.

    Args:
        input_model: Either a filesystem path (``str`` or ``os.PathLike``) to
            a saved model, or an already-loaded model object.
        output_model: Suffix used to build the output file name. When
            ``input_model`` is a path, the file is written to
            ``"<input path without extension>.<output_model>"``; otherwise
            ``output_model`` is used as the path as-is.
        device: ``map_location`` passed to ``torch.load`` when loading from
            a path.
        return_model: If true, return the converted model and write nothing.
            If false, save the converted model and return ``None``.

    Returns:
        The converted model when ``return_model`` is true, else ``None``.

    Note:
        This sets the global torch default dtype to the dtype of the source
        model's first parameter.
    """
    # Accept pathlib.Path (any os.PathLike) as well as plain strings;
    # previously a Path was mistaken for a model object.
    input_is_path = isinstance(input_model, (str, os.PathLike))

    # Load CuEq model
    if input_is_path:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        source_model = torch.load(input_model, map_location=device)
    else:
        source_model = input_model

    # The new model is built under the source model's dtype.
    default_dtype = next(source_model.parameters()).dtype
    torch.set_default_dtype(default_dtype)

    # Extract configuration
    config = extract_config_mace_model(source_model)

    # Get max_L and correlation from config
    max_L = config["hidden_irreps"].lmax
    correlation = config["correlation"]

    # Remove CuEq config so the rebuilt model has no CuEq settings
    config.pop("cueq_config", None)

    logging.info("Creating new model without CuEq settings")
    target_model = source_model.__class__(**config)

    # Transfer weights with proper remapping
    transfer_weights(source_model, target_model, max_L, correlation)

    if return_model:
        return target_model

    # Save model next to the input file when the input was a path
    if input_is_path:
        base = os.path.splitext(input_model)[0]
        output_model = f"{base}.{output_model}"
    logging.warning(f"Saving E3nn model to {output_model}")
    torch.save(target_model, output_model)
    return None


def main():
    """Command-line entry point: parse arguments and call ``run``.

    Note that ``--return_model`` is a ``store_false`` flag: it defaults to
    ``True`` (the model is returned and nothing is written) and passing the
    flag switches ``run`` to saving the converted model to disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Path to input CuEq model")
    parser.add_argument(
        "--output_model",
        # run() appends this to the input path's base name rather than
        # treating it as a standalone path.
        help=(
            "Suffix appended to the input model's base name to form the "
            "output E3nn model path"
        ),
        default="e3nn_model.pt",
    )
    parser.add_argument("--device", default="cpu", help="Device to use")
    parser.add_argument(
        "--return_model",
        action="store_false",
        # Help text fixed: the flag turns return_model OFF, i.e. enables saving.
        help=(
            "Save the converted model to file (without this flag the model "
            "is only returned and nothing is written)"
        ),
    )
    args = parser.parse_args()

    run(
        input_model=args.input_model,
        output_model=args.output_model,
        device=args.device,
        return_model=args.return_model,
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Loading

0 comments on commit 67ec3b3

Please sign in to comment.