Skip to content

Commit

Permalink
added docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
sblackburn-mila committed Jun 25, 2024
1 parent 2e265f6 commit 36012e4
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions crystal_diffusion/models/mace_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,47 @@ def get_normalized_irreps_permutation_indices(irreps: o3.Irreps) -> Tuple[o3.Irr
return sorted_irreps, column_permutation_indices


def reshape_from_mace_to_e3nn(x: torch.Tensor, irreps: o3.Irreps):
def reshape_from_mace_to_e3nn(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tensor:
"""Reshape a MACE input/output tensor to a e3nn.NormActivation compatible format.
MACE uses tensors in the 2D format (ignoring the nodes / batchsize):
---- l = 0 ----
---- l = 1 ----
---- l = 1 ----
---- l = 1 ----
...
And e3nn wants a tensor in the 1D format:
---- l = 0 ---- ---- l= 1 ---- ---- l=2 ---- ...
Args:
x: torch used by MACE. Should be of size (number of nodes, number of channels, (ell_max + 1)^2
irreps: o3 irreps matching the x tensor
Returns:
tensor of size (number of nodes, number of channels * (ell_max + 1)^2) usable by e3nn
"""
node = x.size(0)
# x : node, channel, irreps index
x_ = []
for ell in range(irreps.lmax + 1):
# for example, for l=1, take indices 1, 2, 3 (in the last index) and flatten as a channel * 3 tensor
x_l = x[:, :, (ell ** 2):(ell + 1)**2].reshape(node, -1) # node, channel * (2l + 1)
x_.append(x_l)
# stack the flatten irrep tensors together
return torch.cat(x_, dim=-1)


def reshape_from_e3nn_to_mace(x, irreps):
def reshape_from_e3nn_to_mace(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tensor:
"""Reshape a tensor in the e3nn.NormActivation format to a MACE format.
See reshape_from_mace_to_e3nn for an explanation of the formats
Args:
x: torch used by MACE. Should be of size (number of nodes, number of channels, (ell_max + 1)^2
irreps: o3 irreps matching the x tensor
Returns:
tensor of size (number of nodes, number of channels * (ell_max + 1)^2) usable by e3nn
"""
node = x.size(0)
x_ = []
for ell, s in enumerate(irreps.slices()):
Expand Down

0 comments on commit 36012e4

Please sign in to comment.