Skip to content

Commit

Permalink
change default cueq to ir_mul
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 21, 2024
1 parent 9dca489 commit a200571
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
2 changes: 1 addition & 1 deletion mace/cli/convert_e3nn_cueq.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def run(
# Add cuequivariance config
config["cueq_config"] = CuEquivarianceConfig(
enabled=True,
layout="mul_ir",
layout="ir_mul",
group="O3_e3nn",
optimize_all=True,
)
Expand Down
5 changes: 0 additions & 5 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,6 @@ def _setup(self) -> None:
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
Expand Down Expand Up @@ -717,7 +716,6 @@ def _setup(self) -> None:
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
Expand Down Expand Up @@ -800,7 +798,6 @@ def _setup(self) -> None:
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
Expand Down Expand Up @@ -898,7 +895,6 @@ def _setup(self) -> None:
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
Expand Down Expand Up @@ -1007,7 +1003,6 @@ def _setup(self) -> None:
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = Linear(
irreps_mid,
Expand Down
6 changes: 4 additions & 2 deletions mace/modules/wrapper_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,12 @@ def __new__(
cue.Irreps(cueq_config.group, irreps_out),
layout=cueq_config.layout,
shared_weights=shared_weights,
optimize_fallback=True,
)
instance.original_forward = instance.forward

def cuet_forward(self, x: torch.Tensor) -> torch.Tensor:
return self.original_forward(x, use_fallback=None)
return self.original_forward(x, use_fallback=True)

instance.forward = types.MethodType(cuet_forward, instance)
return instance
Expand Down Expand Up @@ -193,13 +194,14 @@ def __new__(
layout=cueq_config.layout,
shared_weights=shared_weights,
internal_weights=internal_weights,
optimize_fallback=True,
)
instance.original_forward = instance.forward

def cuet_forward(
self, x: torch.Tensor, attrs: torch.Tensor
) -> torch.Tensor:
return self.original_forward(x, attrs, use_fallback=None)
return self.original_forward(x, attrs, use_fallback=True)

instance.forward = types.MethodType(cuet_forward, instance)
return instance
Expand Down

0 comments on commit a200571

Please sign in to comment.