Skip to content

Commit

Permalink
add float32 test cueq
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 22, 2024
1 parent f3c0939 commit 49fb8b2
Showing 1 changed file with 19 additions and 56 deletions.
75 changes: 19 additions & 56 deletions tests/test_cueq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
except ImportError:
CUET_AVAILABLE = False

torch.set_default_dtype(torch.float64)
CUDA_AVAILABLE = torch.cuda.is_available()


@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
Expand Down Expand Up @@ -50,9 +50,11 @@ def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]:
}

@pytest.fixture
def batch(self, device: str):
def batch(self, device: str, default_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
from ase import build

torch.set_default_dtype(default_dtype)

table = tools.AtomicNumberTable([6])

atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
Expand All @@ -75,7 +77,10 @@ def batch(self, device: str):
batch = next(iter(data_loader))
return batch.to(device).to_dict()

@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize(
"device",
["cpu"] + (["cuda"] if CUDA_AVAILABLE else []),
)
@pytest.mark.parametrize(
"interaction_cls_first",
[
Expand All @@ -92,24 +97,26 @@ def batch(self, device: str):
o3.Irreps("32x0e"),
],
)
@pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64])
def test_bidirectional_conversion(
self,
model_config: Dict[str, Any],
batch: Dict[str, torch.Tensor],
device: str,
default_dtype: torch.dtype,
):
if device == "cuda" and not CUDA_AVAILABLE:
pytest.skip("CUDA not available")
torch.manual_seed(42)

# Create original E3nn model
model_e3nn = modules.ScaleShiftMACE(**model_config)
# model_e3nn = model_e3nn.to(device)
model_e3nn = modules.ScaleShiftMACE(**model_config).to(device)

# Convert E3nn to CuEq
model_cueq = run_e3nn_to_cueq(model_e3nn)
# model_cueq = model_cueq.to(device)
model_cueq = run_e3nn_to_cueq(model_e3nn).to(device)

# Convert CuEq back to E3nn
model_e3nn_back = run_cueq_to_e3nn(model_cueq)
# model_e3nn_back = model_e3nn_back.to(device)
model_e3nn_back = run_cueq_to_e3nn(model_cueq).to(device)

# Test forward pass equivalence
out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True)
Expand All @@ -136,14 +143,16 @@ def test_bidirectional_conversion(
loss_e3nn_back.backward()

# Compare gradients for all conversions
tol = 1e-4 if default_dtype == torch.float32 else 1e-8

def print_gradient_diff(name1, p1, name2, p2, conv_type):
if p1.grad is not None and p1.grad.shape == p2.grad.shape:
if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]:
error = torch.abs(p1.grad - p2.grad)
print(
f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}"
)
torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10)
torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=1e-10)

# E3nn to CuEq gradients
for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip(
Expand All @@ -166,49 +175,3 @@ def print_gradient_diff(name1, p1, name2, p2, conv_type):
print_gradient_diff(
name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle"
)

# def test_jit_compile(
# self,
# model_config: Dict[str, Any],
# batch: Dict[str, torch.Tensor],
# device: str,
# ):
# torch.manual_seed(42)

# # Create original E3nn model
# model_e3nn = modules.ScaleShiftMACE(**model_config)
# model_e3nn = model_e3nn.to(device)

# # Convert E3nn to CuEq
# model_cueq = run_e3nn_to_cueq(model_e3nn)
# model_cueq = model_cueq.to(device)

# # Convert CuEq back to E3nn
# model_e3nn_back = run_cueq_to_e3nn(model_cueq)
# model_e3nn_back = model_e3nn_back.to(device)

# # # Compile all models
# model_e3nn_compiled = jit.compile(model_e3nn)
# model_cueq_compiled = jit.compile(model_cueq)
# model_e3nn_back_compiled = jit.compile(model_e3nn_back)

# # Test forward pass equivalence
# out_e3nn = model_e3nn(batch, training=True)
# out_cueq = model_cueq(batch, training=True)
# out_e3nn_back = model_e3nn_back(batch, training=True)

# out_e3nn_compiled = model_e3nn_compiled(batch, training=True)
# out_cueq_compiled = model_cueq_compiled(batch, training=True)
# out_e3nn_back_compiled = model_e3nn_back_compiled(batch, training=True)

# # Check outputs match for both conversions
# torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
# torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"])
# torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"])
# torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"])

# torch.testing.assert_close(out_e3nn["energy"], out_e3nn_compiled["energy"])
# torch.testing.assert_close(out_cueq["energy"], out_cueq_compiled["energy"])
# torch.testing.assert_close(out_e3nn_back["energy"], out_e3nn_back_compiled["energy"])
# torch.testing.assert_close(out_e3nn["forces"], out_e3nn_compiled["forces"])
# torch.testing.assert_close(out_cueq["forces"], out_cueq_compiled["forces"])

0 comments on commit 49fb8b2

Please sign in to comment.