Skip to content

Commit

Permalink
Merge branch 'main' into cueq-support-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 20, 2024
2 parents 4c1ad89 + 53cb6af commit 315d7c5
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 71 deletions.
12 changes: 11 additions & 1 deletion mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ase.stress import full_3x3_to_voigt_6_stress

from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
Expand Down Expand Up @@ -60,10 +61,13 @@ def __init__(
model_type="MACE",
compile_mode=None,
fullgraph=True,
enable_cueq=False,
**kwargs,
):
Calculator.__init__(self, **kwargs)

if enable_cueq:
assert model_type == "MACE", "CuEq only supports MACE models"
compile_mode = None
if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
Expand Down Expand Up @@ -130,6 +134,12 @@ def __init__(
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]
if enable_cueq:
print("Converting models to CuEq for acceleration")
self.models = [
run_e3nn_to_cueq(model, device=device).to(device)
for model in self.models
]

elif models is not None:
if not isinstance(models, list):
Expand Down
23 changes: 11 additions & 12 deletions mace/cli/convert_cueq_e3nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def transfer_symmetric_contractions(
):
"""Transfer symmetric contraction weights from CuEq to E3nn format"""
kmax_pairs = get_kmax_pairs(max_L, correlation)
logging.info(
logging.warning(
f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}"
)

Expand Down Expand Up @@ -95,15 +95,15 @@ def transfer_weights(

# Transfer main weights
transfer_keys = get_transfer_keys()
logging.info("Transferring main weights...")
logging.warning("Transferring main weights...")
for key in transfer_keys:
if key in source_dict: # Check if key exists
target_dict[key] = source_dict[key]
else:
logging.warning(f"Key {key} not found in source model")

# Transfer symmetric contractions
logging.info("Transferring symmetric contractions...")
logging.warning("Transferring symmetric contractions...")
transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)

# Transfer remaining matching keys
Expand All @@ -114,7 +114,7 @@ def transfer_weights(
remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}

if remaining_keys:
logging.info(
logging.warning(
f"Found {len(remaining_keys)} additional matching keys to transfer"
)
for key in remaining_keys:
Expand All @@ -138,24 +138,23 @@ def transfer_weights(


def run(input_model, output_model="_e3nn.model", device="cuda", return_model=True):
# Setup logging
logging.basicConfig(level=logging.INFO)

# Load CuEq model
logging.info(f"Loading CuEq model from {input_model}")
logging.warning("Loading CuEq model")
if isinstance(input_model, str):
source_model = torch.load(input_model, map_location=device)
else:
source_model = input_model

default_dtype = next(source_model.parameters()).dtype
torch.set_default_dtype(default_dtype)
# Extract configuration
logging.info("Extracting model configuration")
logging.warning("Extracting model configuration")
config = extract_config_mace_model(source_model)

# Get max_L and correlation from config
max_L = config["hidden_irreps"].lmax
correlation = config["correlation"]
logging.info(f"Extracted max_L={max_L}, correlation={correlation}")
logging.warning(f"Extracted max_L={max_L}, correlation={correlation}")

# Remove CuEq config
config.pop("cueq_config", None)
Expand All @@ -165,7 +164,7 @@ def run(input_model, output_model="_e3nn.model", device="cuda", return_model=Tru
target_model = source_model.__class__(**config)

# Transfer weights with proper remapping
logging.info("Transferring weights with remapping...")
logging.warning("Transferring weights with remapping...")
transfer_weights(source_model, target_model, max_L, correlation)

if return_model:
Expand All @@ -175,7 +174,7 @@ def run(input_model, output_model="_e3nn.model", device="cuda", return_model=Tru
if isinstance(input_model, str):
base = os.path.splitext(input_model)[0]
output_model = f"{base}.{output_model}"
logging.info(f"Saving E3nn model to {output_model}")
logging.warning(f"Saving E3nn model to {output_model}")
torch.save(target_model, output_model)
return None

Expand Down
27 changes: 13 additions & 14 deletions mace/cli/convert_e3nn_cueq.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def transfer_symmetric_contractions(
):
"""Transfer symmetric contraction weights"""
kmax_pairs = get_kmax_pairs(max_L, correlation)
logging.info(
logging.warning(
f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}"
)

Expand All @@ -63,7 +63,7 @@ def transfer_symmetric_contractions(
for j in ["_max", ".0", ".1"]
],
dim=1,
) # .float()
)
target_dict[f"products.{i}.symmetric_contractions.weight"] = wm


Expand All @@ -80,25 +80,24 @@ def transfer_weights(

# Transfer main weights
transfer_keys = get_transfer_keys()
logging.info("Transferring main weights...")
logging.warning("Transferring main weights...")
for key in transfer_keys:
if key in source_dict: # Check if key exists
target_dict[key] = source_dict[key]
else:
logging.warning(f"Key {key} not found in source model")

# Transfer symmetric contractions
logging.info("Transferring symmetric contractions...")
logging.warning("Transferring symmetric contractions...")
transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)

transferred_keys = set(transfer_keys)
remaining_keys = (
set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
)
remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}

if remaining_keys:
logging.info(
logging.warning(
f"Found {len(remaining_keys)} additional matching keys to transfer"
)
for key in remaining_keys:
Expand Down Expand Up @@ -127,24 +126,24 @@ def run(
return_model=True,
):
# Setup logging
logging.basicConfig(level=logging.INFO)

# Load original model
logging.info(f"Loading model from {input_model}")
# logging.warning(f"Loading model")
# check if input_model is a path or a model
if isinstance(input_model, str):
source_model = torch.load(input_model, map_location=device)
else:
source_model = input_model

default_dtype = next(source_model.parameters()).dtype
torch.set_default_dtype(default_dtype)
# Extract configuration
logging.info("Extracting model configuration")
logging.warning("Extracting model configuration")
config = extract_config_mace_model(source_model)

# Get max_L and correlation from config
max_L = config["hidden_irreps"].lmax
correlation = config["correlation"]
logging.info(f"Extracted max_L={max_L}, correlation={correlation}")
logging.warning(f"Extracted max_L={max_L}, correlation={correlation}")

# Add cuequivariance config
config["cueq_config"] = CuEquivarianceConfig(
Expand All @@ -156,10 +155,10 @@ def run(

# Create new model with cuequivariance config
logging.info("Creating new model with cuequivariance settings")
target_model = source_model.__class__(**config)
target_model = source_model.__class__(**config).to(device)

# Transfer weights with proper remapping
logging.info("Transferring weights with remapping...")
logging.warning("Transferring weights with remapping...")
transfer_weights(source_model, target_model, max_L, correlation)

if return_model:
Expand All @@ -168,7 +167,7 @@ def run(
if isinstance(input_model, str):
base = os.path.splitext(input_model)[0]
output_model = f"{base}.{output_model}"
logging.info(f"Saving CuEq model to {output_model}")
logging.warning(f"Saving CuEq model to {output_model}")
torch.save(target_model, output_model)
return None

Expand Down
32 changes: 18 additions & 14 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import mace
from mace import data, tools
from mace.calculators.foundations_models import mace_mp, mace_off
from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.tools import torch_geometric
from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import (
Expand Down Expand Up @@ -54,8 +56,6 @@
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable
from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq


def main() -> None:
Expand Down Expand Up @@ -551,6 +551,11 @@ def run(args: argparse.Namespace) -> None:
logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
logging.info(loss_fn)

# Cueq
if args.enable_cueq:
logging.info("Converting model to CUEQ for accelerated training")
assert args.model in ["MACE", "ScaleShiftMACE"], "Model must be MACE or ScaleShiftMACE"
model = run_e3nn_to_cueq(deepcopy(model), device=device)
# Optimizer
param_options = get_params_options(args, model)
optimizer: torch.optim.Optimizer
Expand Down Expand Up @@ -602,10 +607,6 @@ def run(args: argparse.Namespace) -> None:

if args.wandb:
setup_wandb(args)
if args.enable_cueq:
logging.info("Converting model to CUEQ for accelerated training")
assert args.model in ["MACE", "ScaleShiftMACE"], "Model must be MACE or ScaleShiftMACE"
model = run_e3nn_to_cueq(model)
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
else:
Expand Down Expand Up @@ -757,16 +758,19 @@ def run(args: argparse.Namespace) -> None:

if rank == 0:
# Save entire model
if args.enable_cueq:
model = run_cueq_to_e3nn(model)
if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
else:
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
model_to_save = deepcopy(model)
if args.enable_cueq:
print("RUNING CUEQ TO E3NN")
print("swa_eval", swa_eval)
model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device)
if args.save_cpu:
model = model.to("cpu")
torch.save(model, model_path)
model_to_save = model_to_save.to("cpu")
torch.save(model_to_save, model_path)
extra_files = {
"commit.txt": commit.encode("utf-8") if commit is not None else b"",
"config.yaml": json.dumps(
Expand All @@ -775,14 +779,14 @@ def run(args: argparse.Namespace) -> None:
}
if swa_eval:
torch.save(
model, Path(args.model_dir) / (args.name + "_stagetwo.model")
model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model")
)
try:
path_complied = Path(args.model_dir) / (
args.name + "_stagetwo_compiled.model"
)
logging.info(f"Compiling model, saving metadata {path_complied}")
model_compiled = jit.compile(deepcopy(model))
model_compiled = jit.compile(deepcopy(model_to_save))
torch.jit.save(
model_compiled,
path_complied,
Expand All @@ -791,13 +795,13 @@ def run(args: argparse.Namespace) -> None:
except Exception as e: # pylint: disable=W0703
pass
else:
torch.save(model, Path(args.model_dir) / (args.name + ".model"))
torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model"))
try:
path_complied = Path(args.model_dir) / (
args.name + "_compiled.model"
)
logging.info(f"Compiling model, saving metadata to {path_complied}")
model_compiled = jit.compile(deepcopy(model))
model_compiled = jit.compile(deepcopy(model_to_save))
torch.jit.save(
model_compiled,
path_complied,
Expand Down
Loading

0 comments on commit 315d7c5

Please sign in to comment.