
Commit 63f813c

remove gaussian prior, add DN sinu freq embedding and support specifying tensor format for each layer

CheukHinHoJerry committed Nov 22, 2024
1 parent f37a447
Showing 7 changed files with 268 additions and 63 deletions.
210 changes: 189 additions & 21 deletions mace/modules/blocks.py
@@ -32,9 +32,14 @@
GaussianBasis,
PolynomialCutoff,
SoftTransform,
continuous_sinous_embedding,
)
from .symmetric_contraction import SymmetricContraction

from functools import partial

AGNOSTIC = False

@compile_mode("script")
class LinearNodeEmbeddingBlock(torch.nn.Module):
def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps):
@@ -228,26 +233,26 @@ def tensor_power_einsum(tensor, N):
result = result.reshape(batch_size, dim ** N, features)
return result
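
The hunk above shows only the tail of `tensor_power_einsum`. A plausible reading, hinted at by the commented `eq` string in `TensorFormatBlock` below, is that it takes the N-fold tensor power of each feature channel with an einsum; the following is an illustrative sketch under that assumption, not the committed implementation:

```python
import torch

def tensor_power_einsum_sketch(tensor: torch.Tensor, N: int) -> torch.Tensor:
    # tensor: [batch, dim, features] -> [batch, dim**N, features]
    batch_size, dim, features = tensor.shape
    # start subscripts at 'i' so they cannot collide with 'b' (batch) or 'f' (features)
    indices = [chr(ord("i") + k) for k in range(N)]
    eq = ",".join(f"b{i}f" for i in indices) + "->b" + "".join(indices) + "f"
    result = torch.einsum(eq, *([tensor] * N))  # [batch, dim, ..., dim, features]
    return result.reshape(batch_size, dim**N, features)

x = torch.randn(2, 3, 4)
print(tensor_power_einsum_sketch(x, 2).shape)  # torch.Size([2, 9, 4])
```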

# @compile_mode("script")
# class TensorFormatBlock(torch.nn.Module):
# """
# Maybe useful for reshaping tensor for efficient operation later.
# """
# def __init__(self, tensor_format, correlation):
# super().__init__()
@compile_mode("script")
class TensorFormatBlock(torch.nn.Module):
"""
Maybe useful for reshaping tensor for efficient operation later.
"""
def __init__(self, tensor_format, correlation):
super().__init__()

# self.tensor_format = tensor_format
# self.correlation = correlation
# #self.irreps_in = irreps_in
# #self.indices = [chr(ord('a') + i) for i in range(N)]
# #self.eq = ','.join(['bi' + 'f' for _ in range(N)]) + '->b' + ''.join(indices) + 'f'
self.tensor_format = tensor_format
self.correlation = correlation
#self.irreps_in = irreps_in
#self.indices = [chr(ord('a') + i) for i in range(N)]
#self.eq = ','.join(['bi' + 'f' for _ in range(N)]) + '->b' + ''.join(indices) + 'f'

# def forward(self, message) -> torch.Tensor:
# batch_size, dim, features = message.shape
# if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]:
# return message
# elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
# return message
def forward(self, message) -> torch.Tensor:
batch_size, dim, features = message.shape
if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]:
return message
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
return message
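
Even re-enabled, the block is a pass-through for every supported format. A quick sanity check (a sketch, assuming `TensorFormatBlock` is importable from `mace.modules.blocks`):

```python
import torch

fmt_block = TensorFormatBlock(tensor_format="symmetric_cp", correlation=3)
msg = torch.randn(4, 9, 16)  # [batch, dim, features]
assert torch.equal(fmt_block(msg), msg)  # forward currently returns the message unchanged
```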


@compile_mode("script")
@@ -260,6 +265,7 @@ def __init__(
learned_radials_dim: int,
use_sc: bool = True,
num_elements: Optional[int] = None,
agnostic: Optional[bool] = False,
tensor_format = "symmetric_cp",
flexible_feats_L = False,
gaussian_prior = False,
@@ -272,6 +278,7 @@ def __init__(
irreps_out=target_irreps,
correlation=correlation,
num_elements=num_elements,
agnostic=agnostic,
tensor_format=tensor_format,
flexible_feats_L=flexible_feats_L,
gaussian_prior=gaussian_prior,
@@ -280,6 +287,7 @@ def __init__(
if tensor_format in ["symmetric_cp", "non_symmetric_cp"]:
mid_irreps = target_irreps
elif tensor_format in ["flexible_non_symmetric_tucker", "flexible_symmetric_tucker",]:
print(target_irreps, correlation)
mid_irreps = make_tucker_irreps_flexible(target_irreps, correlation)
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
mid_irreps = make_tucker_irreps(target_irreps, correlation)
@@ -301,8 +309,7 @@ def forward(
node_feats = self.symmetric_contractions(node_feats, node_attrs)
if self.use_sc and sc is not None:
return self.linear(node_feats) + sc
return self.linear(node_feats)

@compile_mode("script")
class InteractionBlock(torch.nn.Module):
@@ -316,7 +323,9 @@ def __init__(
hidden_irreps: o3.Irreps,
avg_num_neighbors: float,
correlation: int,
gate: Optional[Callable] = torch.nn.functional.silu,
radial_MLP: Optional[List[int]] = None,
agnostic: Optional[bool] = False,
tensor_format: str = "symmetric_cp",
) -> None:
super().__init__()
@@ -330,6 +339,9 @@ def __init__(
if radial_MLP is None:
radial_MLP = [64, 64, 64]
self.radial_MLP = radial_MLP
self.gate = gate
self.agnostic = agnostic
# === tensor format settings ===
self.tensor_format = tensor_format
self.correlation = correlation

@@ -740,7 +752,7 @@ def forward(
if self.tensor_format in ["symmetric_cp", "symmetric_tucker",]:
message = self.linear(original_message) / self.avg_num_neighbors
return (
-                self.tensor_format_layer(self.reshape(message)),
+                self.reshape(message),
sc,
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
elif self.tensor_format in ["flexible_symmetric_tucker"]:
@@ -991,6 +1003,162 @@ def forward(
)


@compile_mode("script")
class RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock(InteractionBlock):
def _setup(self) -> None:
# First linear
self.linear_up = o3.Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = o3.TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
)

# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
print(f"RealAgnosticInteractionGateBlock --> {self.gate}")
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
self.gate,
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps

if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]:
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)
# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
self.reshape = reshape_irreps(self.irreps_out)

elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker", "flexible_non_symmetric_tucker"]:
self.linear = torch.nn.ModuleList([])
# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
self.reshape = torch.nn.ModuleList([])
for _ in range(self.correlation):
self.linear.append(o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
))
self.reshape.append(reshape_irreps(self.irreps_out))

if not getattr(self, "agnostic", False):
ValueError("agnostic not supported yet inRealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock")
# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
else:
## Selector TensorProduct
#self.skip_tp = o3.FullyConnectedTensorProduct(
# self.irreps_out, self.node_feats_irreps, self.irreps_out
#)
pass

self.density_fn = nn.FullyConnectedNet(
[input_dim] + [1,],
self.gate
)

self.sinous_embedding = partial(continuous_sinous_embedding, dim=32, max_density=100)
self.density_linear = torch.nn.Linear(32, self.irreps_out[0].mul, bias=False) # TODO: density embedding model
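# Note (added for clarity): continuous_sinous_embedding, defined in
# mace/modules/radial.py below, produces a 32-dim sinusoidal code of the
# scalar local density; density_linear maps it onto the multiplicity of the
# first (l=0) irrep so it can be added to the scalar channels in forward().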

def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
# learnable per-edge density: tanh of a squared MLP output, so values lie in [0, 1)
density = torch.tanh(self.density_fn(edge_feats) ** 2)

# NO RESCALE
#mji = mji * density

message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]

node_local_density = scatter_sum(
src=density, index=receiver, dim=0, dim_size=num_nodes
)

message = message / (node_local_density + 1)

# embed the local density with sinusoidal features
sin_embedding = self.sinous_embedding(node_local_density.flatten())
density_embedding = self.density_linear(sin_embedding)
# inject the embedded density into the scalar (l=0) channels of the message
message[:, self.irreps_out.slices()[0]] += density_embedding

# === tensor formats ===
original_message = message
if self.tensor_format in ["symmetric_cp", "symmetric_tucker",]:
message = self.linear(original_message) / (node_local_density + 1)
return (
self.reshape(message),
sc,
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
elif self.tensor_format in ["flexible_symmetric_tucker"]:
message = self.linear(original_message) / (node_local_density + 1)
# requires format contraction in SymmetricContraction - no reshape
# to [n_nodes, channels, (lmax + 1) ** 2 ] yet
return (message, sc)
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
message = self.reshape[0](self.linear[0](original_message) / (node_local_density + 1))
message = message.unsqueeze(-1)
for idx in range(1, self.correlation):
_message = self.linear[idx](original_message) / (node_local_density + 1)
_message = self.reshape[idx](_message).unsqueeze(-1)
message = torch.cat((message, _message), dim = -1)
return (
message,
sc,
)
elif self.tensor_format in ["flexible_non_symmetric_tucker"]:
message = self.linear[0](original_message) / (node_local_density + 1)  # [n_nodes, klm]
message = message.unsqueeze(-1)  # [n_nodes, klm, 1]
for idx in range(1, self.correlation):
_message = self.linear[idx](original_message) / (node_local_density + 1)
_message = _message.unsqueeze(-1)
message = torch.cat((message, _message), dim = -1)
return (
message,
sc,
)
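
The forward pass above boils down to a density-normalized aggregation: each edge contributes a learned scalar density in [0, 1), the per-node sum of those densities replaces the fixed `avg_num_neighbors` normalization, and a sinusoidal embedding of the same quantity is injected back into the scalar channels. A minimal self-contained sketch of that pattern, with hypothetical sizes and plain torch in place of e3nn and MACE's `scatter_sum`:

```python
import math
import torch

def sinusoidal_embedding(x: torch.Tensor, dim: int = 32, max_val: float = 100.0) -> torch.Tensor:
    # same construction as continuous_sinous_embedding below (even dim assumed)
    half = dim // 2
    freqs = torch.exp(-math.log(max_val) * torch.arange(half) / half).to(x.device)
    args = x[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

# hypothetical sizes: 6 nodes, 10 edges, 16 channels, 8 radial features
num_nodes, num_edges, channels, emb_dim = 6, 10, 16, 32
edge_feats = torch.randn(num_edges, 8)
mji = torch.randn(num_edges, channels)                  # per-edge messages
receiver = torch.randint(0, num_nodes, (num_edges,))

density_fn = torch.nn.Sequential(torch.nn.Linear(8, 1), torch.nn.SiLU())
density = torch.tanh(density_fn(edge_feats) ** 2)       # per-edge density in [0, 1)

# scatter-sum messages and densities onto the receiving nodes
message = torch.zeros(num_nodes, channels).index_add_(0, receiver, mji)
node_density = torch.zeros(num_nodes, 1).index_add_(0, receiver, density)

# density normalization in place of a fixed avg_num_neighbors
message = message / (node_density + 1)

# embed the local density and inject it into the (scalar) channels
density_linear = torch.nn.Linear(emb_dim, channels, bias=False)
message = message + density_linear(sinusoidal_embedding(node_density.flatten()))
print(message.shape)  # torch.Size([6, 16])
```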

@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
4 changes: 3 additions & 1 deletion mace/modules/irreps_tools.py
@@ -31,7 +31,9 @@ def make_tucker_irreps(target_irreps, correlation):
tp_irreps += o3.Irreps(f"{num_feats}x{tmp_irreps[0].ir}")
return tp_irreps

-def make_tucker_irreps_flexible(target_irreps, correlation):
+def make_tucker_irreps_flexible(input_target_irreps, correlation):
+    # make sure it is an o3.Irreps, not a string
+    target_irreps = o3.Irreps(input_target_irreps)
tp_irreps = o3.Irreps()
for ir in target_irreps:
tmp_irreps = o3.Irreps(str(ir))
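The coercion matters because callers sometimes pass the target irreps as a string; `o3.Irreps` accepts either form, and passing an existing `Irreps` through is harmless. A quick illustration (the values are arbitrary):

```python
from e3nn import o3

irreps_from_str = o3.Irreps("16x0e + 16x1o")
irreps_from_obj = o3.Irreps(irreps_from_str)  # already an Irreps: equivalent result

assert irreps_from_str == irreps_from_obj
for mul_ir in irreps_from_str:
    print(mul_ir.mul, mul_ir.ir)  # 16 0e, then 16 1o
```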
12 changes: 6 additions & 6 deletions mace/modules/models.py
@@ -124,7 +124,7 @@ def __init__(
correlation=correlation[0],
radial_MLP=radial_MLP,
#
-            tensor_format=tensor_format,
+            tensor_format=tensor_format[0],
)
self.interactions = torch.nn.ModuleList([inter])

@@ -142,7 +142,7 @@ def __init__(
use_sc=use_sc_first,
learned_radials_dim=inter.conv_tp.weight_numel,
#
-            tensor_format=tensor_format,
+            tensor_format=tensor_format[0],
flexible_feats_L=flexible_feats_L,
gaussian_prior=gaussian_prior,
)
@@ -175,7 +175,7 @@ def __init__(
radial_MLP=radial_MLP,
correlation=correlation[i + 1],
#
-            #tensor_format=tensor_format,
+            tensor_format=tensor_format[i + 1],
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
@@ -186,9 +186,9 @@ def __init__(
use_sc=True,
learned_radials_dim=inter.conv_tp.weight_numel,
##
-                # tensor_format=tensor_format,
-                # flexible_feats_L=flexible_feats_L,
-                # gaussian_prior=gaussian_prior,
+                tensor_format=tensor_format[i + 1],
+                flexible_feats_L=flexible_feats_L,
+                gaussian_prior=gaussian_prior,
# learned_radials_dim=inter.conv_tp.weight_numel
)
self.products.append(prod)
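With this change `tensor_format` is indexed per interaction layer, mirroring how `correlation` is already handled, so configs must now supply one format per layer. A hypothetical helper (not part of the commit) that broadcasts a single format to the expected list:

```python
from typing import List, Union

def as_per_layer(fmt: Union[str, List[str]], num_layers: int) -> List[str]:
    # hypothetical convenience: accept one format for all layers or an explicit list
    if isinstance(fmt, str):
        return [fmt] * num_layers
    if len(fmt) != num_layers:
        raise ValueError(f"expected {num_layers} tensor formats, got {len(fmt)}")
    return list(fmt)

print(as_per_layer("symmetric_cp", 3))
print(as_per_layer(["symmetric_cp", "non_symmetric_tucker", "symmetric_tucker"], 3))
```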
21 changes: 20 additions & 1 deletion mace/modules/radial.py
@@ -11,7 +11,7 @@

from mace.tools.compile import simplify_if_compile
from mace.tools.scatter import scatter_sum

import math

@compile_mode("script")
class BesselBasis(torch.nn.Module):
@@ -55,6 +55,25 @@ def __repr__(self):
f"trainable={self.bessel_weights.requires_grad})"
)

def continuous_sinous_embedding(densities, dim, max_density=100):
    """
    Create sinusoidal density embeddings (adapted from diffusion-style timestep embeddings).
    :param densities: a 1-D Tensor of N density values, one per node.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_density: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of embeddings.
    """
half = dim // 2
freqs = torch.exp(
-math.log(max_density) * torch.arange(start=0, end=half) / half
).to(device=densities.device)
args = densities[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
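
A quick shape check of this embedding (the values here are arbitrary):

```python
import torch

densities = 10.0 * torch.rand(5)  # 5 per-node density values

emb = continuous_sinous_embedding(densities, dim=32, max_density=100)
print(emb.shape)  # torch.Size([5, 32]) -- cos features first, then sin

# an odd dim is zero-padded up to the requested size
emb_odd = continuous_sinous_embedding(densities, dim=33, max_density=100)
print(emb_odd.shape)  # torch.Size([5, 33])
```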

@compile_mode("script")
class ChebychevBasis(torch.nn.Module):
(diffs for the remaining 3 changed files not shown)
