From a20057153a17aca81fc6f02e2d0c6837a18a7da8 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:56:25 +0000 Subject: [PATCH] change default cueq to ir_mul --- mace/cli/convert_e3nn_cueq.py | 2 +- mace/modules/blocks.py | 5 ----- mace/modules/wrapper_ops.py | 6 ++++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mace/cli/convert_e3nn_cueq.py b/mace/cli/convert_e3nn_cueq.py index 45e07257..45bd1555 100644 --- a/mace/cli/convert_e3nn_cueq.py +++ b/mace/cli/convert_e3nn_cueq.py @@ -138,7 +138,7 @@ def run( # Add cuequivariance config config["cueq_config"] = CuEquivarianceConfig( enabled=True, - layout="mul_ir", + layout="ir_mul", group="O3_e3nn", optimize_all=True, ) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 7bc3561f..ea0e228b 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -634,7 +634,6 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( irreps_mid, @@ -717,7 +716,6 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( irreps_mid, @@ -800,7 +798,6 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( irreps_mid, @@ -898,7 +895,6 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( irreps_mid, @@ -1007,7 +1003,6 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( irreps_mid, diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 3eb8120b..580b4a0a 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -104,11 +104,12 @@ def __new__( cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, + optimize_fallback=True, ) instance.original_forward = instance.forward def cuet_forward(self, x: torch.Tensor) -> torch.Tensor: - return self.original_forward(x, use_fallback=None) + return self.original_forward(x, use_fallback=True) instance.forward = types.MethodType(cuet_forward, instance) return instance @@ -193,13 +194,14 @@ def __new__( layout=cueq_config.layout, shared_weights=shared_weights, internal_weights=internal_weights, + optimize_fallback=True, ) instance.original_forward = instance.forward def cuet_forward( self, x: torch.Tensor, attrs: torch.Tensor ) -> torch.Tensor: - return self.original_forward(x, attrs, use_fallback=None) + return self.original_forward(x, attrs, use_fallback=True) instance.forward = types.MethodType(cuet_forward, instance) return instance