Skip to content

Commit

Permalink
finishing up sym/non-sym cp/tucker and small clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
CheukHinHoJerry committed Oct 28, 2024
1 parent 09b9344 commit 7712975
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 171 deletions.
32 changes: 9 additions & 23 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def tensor_power_einsum(tensor, N):

@compile_mode("script")
class TensorFormatBlock(torch.nn.Module):
"""
Maybe useful for reshaping tensor for efficient operation later.
"""
def __init__(self, tensor_format, correlation):
super().__init__()

Expand All @@ -239,17 +242,10 @@ def __init__(self, tensor_format, correlation):

def forward(self, message) -> torch.Tensor:
batch_size, dim, features = message.shape
if self.tensor_format == "symmetric_cp":
if self.tensor_format in ["symmetric_cp", "symmetric_tucker"]:
return message
elif self.tensor_format == "symmetric_tucker":
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
return message
# message = [message] * correlation
# message = torch.einsum(eq, *tensors)
# # K = message.shape[-2]
# # for i in range(self.correlation - 1):
# # message = message.unsqueeze(-2)
# # message = message.repeat([1, 1, ] + [K, ] * (self.correlation - 1) + [1])
# return message.reshape(batch_size, dim ** correlation, features)


@compile_mode("script")
Expand Down Expand Up @@ -281,7 +277,7 @@ def __init__(
internal_weights=True,
shared_weights=True,
)
elif tensor_format == "symmetric_tucker":
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
tucker_irreps = make_tucker_irreps(target_irreps, correlation)
self.linear = o3.Linear(
tucker_irreps,
Expand Down Expand Up @@ -687,7 +683,7 @@ def _setup(self) -> None:
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps

if self.tensor_format == "symmetric_cp":
if self.tensor_format in ["symmetric_cp", "symmetric_tucker"]:
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)
Expand All @@ -697,7 +693,7 @@ def _setup(self) -> None:
)
self.reshape = reshape_irreps(self.irreps_out)

elif self.tensor_format == "non_symmetric_cp":
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
self.linear = []
# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
Expand All @@ -710,16 +706,6 @@ def _setup(self) -> None:
))
self.reshape.append(reshape_irreps(self.irreps_out))

elif self.tensor_format == "symmetric_tucker":
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)
# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
self.reshape = reshape_irreps(self.irreps_out)

# self.reshape = reshape_irreps(self.irreps_out)
self.tensor_format_layer = TensorFormatBlock(self.tensor_format, self.correlation)

Expand Down Expand Up @@ -750,7 +736,7 @@ def forward(
self.tensor_format_layer(self.reshape(message)),
sc,
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
elif self.tensor_format == "non_symmetric_cp":
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
message = self.reshape[0](self.linear[0](original_message))
message = message.unsqueeze(-1)
for idx in range(1, self.correlation):
Expand Down
2 changes: 1 addition & 1 deletion mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
elif tensor_format == "symmetric_tucker":
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
self.readouts.append(
#LinearReadoutBlock(make_tp_irreps(hidden_irreps, correlation[0]), o3.Irreps(f"{len(heads)}x0e"))
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
Expand Down
Loading

0 comments on commit 7712975

Please sign in to comment.