Commit: add test_run_train_cueq
ilyes319 committed Nov 19, 2024
1 parent 2e8e5c5 commit ef42dba
Showing 5 changed files with 163 additions and 144 deletions.
9 changes: 8 additions & 1 deletion mace/cli/run_train.py
@@ -54,6 +54,8 @@
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable
from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq


def main() -> None:
@@ -600,7 +602,10 @@ def run(args: argparse.Namespace) -> None:

if args.wandb:
setup_wandb(args)

if args.enable_cueq:
logging.info("Converting model to CUEQ for accelerated training")
assert args.model in ["MACE", "ScaleShiftMACE"], "Model must be MACE or ScaleShiftMACE"
model = run_e3nn_to_cueq(model)
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
else:
@@ -752,6 +757,8 @@ def run(args: argparse.Namespace) -> None:

if rank == 0:
# Save entire model
if args.enable_cueq:
model = run_cueq_to_e3nn(model)
if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
else:
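Taken together, the run_train.py hunks make cuequivariance a transparent round trip: when --enable_cueq is set, the e3nn model is converted to its cueq counterpart before training (and before any DDP wrapping), and converted back before the final checkpoint is written, so the saved artifact remains a plain e3nn model. A minimal sketch of that pattern, assuming a `model` and `args` shaped as in run() above:

    # Sketch only: mirrors the hunks above; `model` and `args` are assumed
    # to be the objects run() already has in scope.
    from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
    from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn

    if args.enable_cueq:
        # Only MACE / ScaleShiftMACE architectures are convertible.
        model = run_e3nn_to_cueq(model)  # train with cueq-accelerated ops

    # ... training loop runs on the (possibly converted) model ...

    if args.enable_cueq:
        model = run_cueq_to_e3nn(model)  # convert back before saving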
46 changes: 1 addition & 45 deletions mace/tools/arg_parser.py
@@ -662,55 +662,11 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
)
# option for cuequivariance acceleration
parser.add_argument(
"--cue_enabled",
"--enable_cueq",
help="Enable cuequivariance acceleration",
type=str2bool,
default=False,
)
parser.add_argument(
"--cue_layout",
help="Memory layout for cuequivariance tensors",
type=str,
choices=["mul_ir", "ir_mul"],
default="mul_ir",
)
parser.add_argument(
"--cue_group",
help="Symmetry group for cuequivariance",
type=str,
choices=["O3_e3nn, O3"],
default="O3_e3nn",
)
parser.add_argument(
"--cue_optimize_all",
help="Enable all cuequivariance optimizations",
type=str2bool,
default=False,
)
parser.add_argument(
"--cue_optimize_linear",
help="Enable cuequivariance linear layer optimization",
type=str2bool,
default=False,
)
parser.add_argument(
"--cue_optimize_channelwise",
help="Enable cuequivariance channelwise optimization",
type=str2bool,
default=False,
)
parser.add_argument(
"--cue_optimize_symmetric",
help="Enable cuequivariance symmetric contraction optimization",
type=str2bool,
default=False,
)
parser.add_argument(
"--cue_optimize_fctp",
help="Enable cuequivariance fully connected tensor product optimization",
type=str2bool,
default=False,
)
# options for using Weights and Biases for experiment tracking
# to install see https://wandb.ai
parser.add_argument(
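The arg_parser.py hunk collapses the eight fine-grained --cue_* options into a single boolean --enable_cueq. A hedged parsing sketch; the --name and --train_file values are placeholders standing in for whatever required arguments a real run supplies:

    from mace.tools.arg_parser import build_default_arg_parser

    # Placeholder values; a real training run passes many more options.
    args = build_default_arg_parser().parse_args(
        ["--name=test", "--train_file=fit.xyz", "--enable_cueq=True"]
    )
    assert args.enable_cueq is True  # str2bool converts the string flag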
11 changes: 0 additions & 11 deletions mace/tools/model_script_utils.py
@@ -29,16 +29,6 @@ def configure_model(
logging.info(
f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}"
)
cueq_config = CuEquivarianceConfig(
enabled=args.cue_enabled,
layout=args.cue_layout,
group=args.cue_group,
optimize_all=args.cue_optimize_all,
optimize_linear=args.cue_optimize_linear,
optimize_channelwise=args.cue_optimize_channelwise,
optimize_symmetric=args.cue_optimize_symmetric,
optimize_fctp=args.cue_optimize_fctp,
)
logging.info("===========MODEL DETAILS===========")

if args.scaling == "no_scaling":
@@ -120,7 +110,6 @@ def configure_model(
atomic_energies=atomic_energies,
avg_num_neighbors=args.avg_num_neighbors,
atomic_numbers=z_table.zs,
cueq_config=cueq_config,
)
model_config_foundation = None

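Net effect of the model_script_utils.py hunks: configure_model no longer builds a CuEquivarianceConfig from CLI flags, and cueq_config is dropped from the model constructor arguments. After this commit the cueq path is entered only by converting an already-built e3nn model in run_train.py, as sketched above.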
174 changes: 87 additions & 87 deletions tests/test_cueq.py
@@ -94,79 +94,7 @@ def batch(self, device: str):
o3.Irreps("32x0e"),
],
)
# def test_bidirectional_conversion(
# self,
# model_config: Dict[str, Any],
# batch: Dict[str, torch.Tensor],
# device: str,
# ):
# torch.manual_seed(42)

# # Create original E3nn model
# model_e3nn = modules.ScaleShiftMACE(**model_config)
# model_e3nn = model_e3nn.to(device)

# # Convert E3nn to CuEq
# model_cueq = run_e3nn_to_cueq(model_e3nn)
# model_cueq = model_cueq.to(device)

# # Convert CuEq back to E3nn
# model_e3nn_back = run_cueq_to_e3nn(model_cueq)
# model_e3nn_back = model_e3nn_back.to(device)

# # Test forward pass equivalence
# out_e3nn = model_e3nn(batch, training=True)
# out_cueq = model_cueq(batch, training=True)
# out_e3nn_back = model_e3nn_back(batch, training=True)

# # Check outputs match for both conversions
# torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
# torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"])
# torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"])
# torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"])

# # Test backward pass equivalence
# loss_e3nn = out_e3nn["energy"].sum()
# loss_cueq = out_cueq["energy"].sum()
# loss_e3nn_back = out_e3nn_back["energy"].sum()

# loss_e3nn.backward()
# loss_cueq.backward()
# loss_e3nn_back.backward()

# # Compare gradients for all conversions
# def print_gradient_diff(name1, p1, name2, p2, conv_type):
# if p1.grad is not None and p1.grad.shape == p2.grad.shape:
# if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]:
# error = torch.abs(p1.grad - p2.grad)
# print(
# f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}"
# )
# torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10)

# # E3nn to CuEq gradients
# for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip(
# model_e3nn.named_parameters(), model_cueq.named_parameters()
# ):
# print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq")

# # CuEq to E3nn gradients
# for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip(
# model_cueq.named_parameters(), model_e3nn_back.named_parameters()
# ):
# print_gradient_diff(
# name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn"
# )

# # Full circle comparison (E3nn -> E3nn)
# for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip(
# model_e3nn.named_parameters(), model_e3nn_back.named_parameters()
# ):
# print_gradient_diff(
# name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle"
# )

def test_jit_compile(
def test_bidirectional_conversion(
self,
model_config: Dict[str, Any],
batch: Dict[str, torch.Tensor],
@@ -186,28 +114,100 @@ def test_jit_compile(
model_e3nn_back = run_cueq_to_e3nn(model_cueq)
model_e3nn_back = model_e3nn_back.to(device)

# # Compile all models
model_e3nn_compiled = jit.compile(model_e3nn)
model_cueq_compiled = jit.compile(model_cueq)
model_e3nn_back_compiled = jit.compile(model_e3nn_back)

# Test forward pass equivalence
out_e3nn = model_e3nn(batch, training=True)
out_cueq = model_cueq(batch, training=True)
out_e3nn_back = model_e3nn_back(batch, training=True)

out_e3nn_compiled = model_e3nn_compiled(batch, training=True)
out_cueq_compiled = model_cueq_compiled(batch, training=True)
out_e3nn_back_compiled = model_e3nn_back_compiled(batch, training=True)

# Check outputs match for both conversions
torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"])
torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"])
torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"])

torch.testing.assert_close(out_e3nn["energy"], out_e3nn_compiled["energy"])
torch.testing.assert_close(out_cueq["energy"], out_cueq_compiled["energy"])
torch.testing.assert_close(out_e3nn_back["energy"], out_e3nn_back_compiled["energy"])
torch.testing.assert_close(out_e3nn["forces"], out_e3nn_compiled["forces"])
torch.testing.assert_close(out_cueq["forces"], out_cueq_compiled["forces"])
# Test backward pass equivalence
loss_e3nn = out_e3nn["energy"].sum()
loss_cueq = out_cueq["energy"].sum()
loss_e3nn_back = out_e3nn_back["energy"].sum()

loss_e3nn.backward()
loss_cueq.backward()
loss_e3nn_back.backward()

# Compare gradients for all conversions
def print_gradient_diff(name1, p1, name2, p2, conv_type):
if p1.grad is not None and p1.grad.shape == p2.grad.shape:
if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]:
error = torch.abs(p1.grad - p2.grad)
print(
f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}"
)
torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10)

# E3nn to CuEq gradients
for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip(
model_e3nn.named_parameters(), model_cueq.named_parameters()
):
print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq")

# CuEq to E3nn gradients
for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip(
model_cueq.named_parameters(), model_e3nn_back.named_parameters()
):
print_gradient_diff(
name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn"
)

# Full circle comparison (E3nn -> E3nn)
for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip(
model_e3nn.named_parameters(), model_e3nn_back.named_parameters()
):
print_gradient_diff(
name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle"
)

# def test_jit_compile(
# self,
# model_config: Dict[str, Any],
# batch: Dict[str, torch.Tensor],
# device: str,
# ):
# torch.manual_seed(42)

# # Create original E3nn model
# model_e3nn = modules.ScaleShiftMACE(**model_config)
# model_e3nn = model_e3nn.to(device)

# # Convert E3nn to CuEq
# model_cueq = run_e3nn_to_cueq(model_e3nn)
# model_cueq = model_cueq.to(device)

# # Convert CuEq back to E3nn
# model_e3nn_back = run_cueq_to_e3nn(model_cueq)
# model_e3nn_back = model_e3nn_back.to(device)

# # # Compile all models
# model_e3nn_compiled = jit.compile(model_e3nn)
# model_cueq_compiled = jit.compile(model_cueq)
# model_e3nn_back_compiled = jit.compile(model_e3nn_back)

# # Test forward pass equivalence
# out_e3nn = model_e3nn(batch, training=True)
# out_cueq = model_cueq(batch, training=True)
# out_e3nn_back = model_e3nn_back(batch, training=True)

# out_e3nn_compiled = model_e3nn_compiled(batch, training=True)
# out_cueq_compiled = model_cueq_compiled(batch, training=True)
# out_e3nn_back_compiled = model_e3nn_back_compiled(batch, training=True)

# # Check outputs match for both conversions
# torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
# torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"])
# torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"])
# torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"])

# torch.testing.assert_close(out_e3nn["energy"], out_e3nn_compiled["energy"])
# torch.testing.assert_close(out_cueq["energy"], out_cueq_compiled["energy"])
# torch.testing.assert_close(out_e3nn_back["energy"], out_e3nn_back_compiled["energy"])
# torch.testing.assert_close(out_e3nn["forces"], out_e3nn_compiled["forces"])
# torch.testing.assert_close(out_cueq["forces"], out_cueq_compiled["forces"])
67 changes: 67 additions & 0 deletions tests/test_run_train.py
@@ -847,3 +847,70 @@ def test_run_train_multihead_replay_custum_finetuning(
assert len(Es) == len(fitting_configs)
assert all(isinstance(E, float) for E in Es)
assert len(set(Es)) > 1 # Ens

def test_run_train_cueq(tmp_path, fitting_configs):
ase.io.write(tmp_path / "fit.xyz", fitting_configs)

mace_params = _mace_params.copy()
mace_params["checkpoints_dir"] = str(tmp_path)
mace_params["model_dir"] = str(tmp_path)
mace_params["train_file"] = tmp_path / "fit.xyz"
mace_params["enable_cueq"] = True

# make sure run_train.py is using the mace that is currently being tested
run_env = os.environ.copy()
sys.path.insert(0, str(Path(__file__).parent.parent))
run_env["PYTHONPATH"] = ":".join(sys.path)
print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])

cmd = (
sys.executable
+ " "
+ str(run_train)
+ " "
+ " ".join(
[
(f"--{k}={v}" if v is not None else f"--{k}")
for k, v in mace_params.items()
]
)
)

p = subprocess.run(cmd.split(), env=run_env, check=True)
assert p.returncode == 0

calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu")

Es = []
for at in fitting_configs:
at.calc = calc
Es.append(at.get_potential_energy())

print("Es", Es)
# from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
ref_Es = [
0.0,
0.0,
-0.039181344585828524,
-0.0915223395136733,
-0.14953484236456582,
-0.06662480820063998,
-0.09983737353050133,
0.12477442296789745,
-0.06486086271762856,
-0.1460607988519944,
0.12886334908465508,
-0.14000990081920373,
-0.05319886578958313,
0.07780520158391,
-0.08895480281886901,
-0.15474719614734422,
0.007756765146527644,
-0.044879267197498685,
-0.036065736712447574,
-0.24413743841886623,
-0.0838104612106429,
-0.14751978636626545,
]

assert np.allclose(Es, ref_Es)
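
The new test drives run_train.py end to end in a subprocess with --enable_cueq=True, then reloads the saved MACE.model through MACECalculator on CPU and compares per-configuration energies against the reference values above. Loading the checkpoint as a regular model also implicitly exercises the cueq-to-e3nn conversion performed before saving.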
