From 63f813cdbc38b841ccfa167c17e4db29daf1f0b1 Mon Sep 17 00:00:00 2001 From: CheukHinHoJerry Date: Fri, 22 Nov 2024 01:56:19 -0800 Subject: [PATCH] [Verified] remove gaussian prior, add DN sinu freq embedding and support specifying tensor format for each layer --- mace/modules/blocks.py | 210 +++++++++++++++++++++++--- mace/modules/irreps_tools.py | 4 +- mace/modules/models.py | 12 +- mace/modules/radial.py | 21 ++- mace/modules/symmetric_contraction.py | 49 +++--- mace/tools/arg_parser.py | 28 ++-- mace/tools/arg_parser_tools.py | 7 +- 7 files changed, 268 insertions(+), 63 deletions(-) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index fc3398a0..df7af412 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -32,9 +32,14 @@ GaussianBasis, PolynomialCutoff, SoftTransform, + continuous_sinous_embedding, ) from .symmetric_contraction import SymmetricContraction +from functools import partial + +AGNOSTIC = False + @compile_mode("script") class LinearNodeEmbeddingBlock(torch.nn.Module): def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): @@ -228,26 +233,26 @@ def tensor_power_einsum(tensor, N): result = result.reshape(batch_size, dim ** N, features) return result -# @compile_mode("script") -# class TensorFormatBlock(torch.nn.Module): -# """ -# Maybe useful for reshaping tensor for efficient operation later. -# """ -# def __init__(self, tensor_format, correlation): -# super().__init__() +@compile_mode("script") +class TensorFormatBlock(torch.nn.Module): + """ + Maybe useful for reshaping tensor for efficient operation later. + """ + def __init__(self, tensor_format, correlation): + super().__init__() -# self.tensor_format = tensor_format -# self.correlation = correlation -# #self.irreps_in = irreps_in -# #self.indices = [chr(ord('a') + i) for i in range(N)] -# #self.eq = ','.join(['bi' + 'f' for _ in range(N)]) + '->b' + ''.join(indices) + 'f' + self.tensor_format = tensor_format + self.correlation = correlation + #self.irreps_in = irreps_in + #self.indices = [chr(ord('a') + i) for i in range(N)] + #self.eq = ','.join(['bi' + 'f' for _ in range(N)]) + '->b' + ''.join(indices) + 'f' -# def forward(self, message) -> torch.Tensor: -# batch_size, dim, features = message.shape -# if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]: -# return message -# elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]: -# return message + def forward(self, message) -> torch.Tensor: + batch_size, dim, features = message.shape + if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]: + return message + elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]: + return message @compile_mode("script") @@ -260,6 +265,7 @@ def __init__( learned_radials_dim: int, use_sc: bool = True, num_elements: Optional[int] = None, + agnostic: Optional[bool] = False, tensor_format = "symmetric_cp", flexible_feats_L = False, gaussian_prior = False, @@ -272,6 +278,7 @@ def __init__( irreps_out=target_irreps, correlation=correlation, num_elements=num_elements, + agnostic=agnostic, tensor_format=tensor_format, flexible_feats_L=flexible_feats_L, gaussian_prior=gaussian_prior, @@ -280,6 +287,7 @@ def __init__( if tensor_format in ["symmetric_cp", "non_symmetric_cp"]: mid_irreps = target_irreps elif tensor_format in ["flexible_non_symmetric_tucker", "flexible_symmetric_tucker",]: + print(target_irreps, correlation) mid_irreps = make_tucker_irreps_flexible(target_irreps, correlation) elif 
tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]: mid_irreps = make_tucker_irreps(target_irreps, correlation) @@ -301,8 +309,7 @@ def forward( node_feats = self.symmetric_contractions(node_feats, node_attrs) if self.use_sc and sc is not None: return self.linear(node_feats) + sc - return self.linear(node_feats) - + return self.linear(node_feats) @compile_mode("script") class InteractionBlock(torch.nn.Module): @@ -316,7 +323,9 @@ def __init__( hidden_irreps: o3.Irreps, avg_num_neighbors: float, correlation: int, + gate: Optional[Callable] = torch.nn.functional.silu, radial_MLP: Optional[List[int]] = None, + agnostic: Optional[bool] = False, tensor_format: str = "symmetric_cp", ) -> None: super().__init__() @@ -330,6 +339,9 @@ def __init__( if radial_MLP is None: radial_MLP = [64, 64, 64] self.radial_MLP = radial_MLP + self.gate = gate + self.agnostic = agnostic + # === tensor format stuffs === self.tensor_format = tensor_format self.correlation = correlation @@ -740,7 +752,7 @@ def forward( if self.tensor_format in ["symmetric_cp", "symmetric_tucker",]: message = self.linear(original_message) / self.avg_num_neighbors return ( - self.tensor_format_layer(self.reshape(message)), + self.reshape(message), sc, ) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2] elif self.tensor_format in ["flexible_symmetric_tucker"]: @@ -991,6 +1003,162 @@ def forward( ) +@compile_mode("script") +class RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + print(f"RealAgnosticInteractionGateBlock --> {self.gate}") + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + self.gate, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + + if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]: + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker", "flexible_non_symmetric_tucker"]: + self.linear = torch.nn.ModuleList([]) + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = torch.nn.ModuleList([]) + for _ in range(self.correlation): + self.linear.append(o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + )) + self.reshape.append(reshape_irreps(self.irreps_out)) + + if not getattr(self, "agnostic", False): + ValueError("agnostic not supported yet inRealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock") + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + 
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + else: + ## Selector TensorProduct + #self.skip_tp = o3.FullyConnectedTensorProduct( + # self.irreps_out, self.node_feats_irreps, self.irreps_out + #) + pass + + self.density_fn = nn.FullyConnectedNet( + [input_dim] + [1,], + self.gate + ) + + self.sinous_embedding = partial(continuous_sinous_embedding, dim=32, max_density=100) + self.density_linear = torch.nn.Linear(32, self.irreps_out[0].mul, bias=False) # TODO: density embedding model + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + # learnable density funciton with + density = torch.tanh(self.density_fn(edge_feats)**2) + + # NO RESCALE + #mji = mji * density + + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + + node_local_density = scatter_sum( + src=density, index=receiver, dim=0, dim_size=num_nodes + ) + + message = message / (node_local_density + 1) + + # density_embedding + sin_embedding = self.sinous_embedding(node_local_density.flatten()) + density_embedding = self.density_linear(sin_embedding) + # density inject + message[:, self.irreps_out.slices()[0]] += density_embedding + + # == tensor formats === + original_message = message + if self.tensor_format in ["symmetric_cp", "symmetric_tucker",]: + message = self.linear(original_message) / (node_local_density + 1) + return ( + self.reshape(message), + sc, + ) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2] + elif self.tensor_format in ["flexible_symmetric_tucker"]: + message = self.linear(original_message) / (node_local_density + 1) + # requires format contraction in SymmetricContraction - no reshape + # to [n_nodes, channels, (lmax + 1) ** 2 ] yet + return (message, sc) + elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]: + message = self.reshape[0](self.linear[0](original_message) / (node_local_density + 1)) + message = message.unsqueeze(-1) + for idx in range(1, self.correlation): + _message = self.linear[idx](original_message) / (node_local_density + 1) + _message = self.reshape[idx](_message).unsqueeze(-1) + message = torch.cat((message, _message), dim = -1) + return ( + message, + sc, + ) + elif self.tensor_format in ["flexible_non_symmetric_tucker"]: + message = self.linear[0](original_message / (node_local_density + 1)) # [n_nodes, klm] + message = message.unsqueeze(-1) # [n_nnodes, klm, 1] + for idx in range(1, self.correlation): + _message = self.linear[idx](original_message) / (node_local_density + 1) + _message = _message.unsqueeze(-1) + message = torch.cat((message, _message), dim = -1) + return ( + message, + sc, + ) + @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 3bcc8583..e914b13d 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -31,7 +31,9 @@ def make_tucker_irreps(target_irreps, correlation): tp_irreps += o3.Irreps(f"{num_feats}x{tmp_irreps[0].ir}") return tp_irreps -def 
make_tucker_irreps_flexible(target_irreps, correlation):
+def make_tucker_irreps_flexible(input_target_irreps, correlation):
+    # make sure it is not a string but an o3.Irreps
+    target_irreps = o3.Irreps(input_target_irreps)
     tp_irreps = o3.Irreps()
     for ir in target_irreps:
         tmp_irreps = o3.Irreps(str(ir))
diff --git a/mace/modules/models.py b/mace/modules/models.py
index 1e2ba5e6..a8ebeeca 100644
--- a/mace/modules/models.py
+++ b/mace/modules/models.py
@@ -124,7 +124,7 @@ def __init__(
             correlation=correlation[0],
             radial_MLP=radial_MLP,
             #
-            tensor_format=tensor_format,
+            tensor_format=tensor_format[0],
         )
         self.interactions = torch.nn.ModuleList([inter])
@@ -142,7 +142,7 @@ def __init__(
             use_sc=use_sc_first,
             learned_radials_dim=inter.conv_tp.weight_numel,
             #
-            tensor_format=tensor_format,
+            tensor_format=tensor_format[0],
             flexible_feats_L=flexible_feats_L,
             gaussian_prior=gaussian_prior,
         )
@@ -175,7 +175,7 @@ def __init__(
                 radial_MLP=radial_MLP,
                 correlation=correlation[i + 1],
                 #
-                #tensor_format=tensor_format,
+                tensor_format=tensor_format[i + 1],
             )
             self.interactions.append(inter)
             prod = EquivariantProductBasisBlock(
@@ -186,9 +186,9 @@ def __init__(
                 use_sc=True,
                 learned_radials_dim=inter.conv_tp.weight_numel,
                 ##
-                # tensor_format=tensor_format,
-                # flexible_feats_L=flexible_feats_L,
-                # gaussian_prior=gaussian_prior,
+                tensor_format=tensor_format[i + 1],
+                flexible_feats_L=flexible_feats_L,
+                gaussian_prior=gaussian_prior,
                 # learned_radials_dim=inter.conv_tp.weight_numel
             )
             self.products.append(prod)
diff --git a/mace/modules/radial.py b/mace/modules/radial.py
index a928c184..39a4e783 100644
--- a/mace/modules/radial.py
+++ b/mace/modules/radial.py
@@ -11,7 +11,7 @@
 from mace.tools.compile import simplify_if_compile
 from mace.tools.scatter import scatter_sum
-
+import math
 
 @compile_mode("script")
 class BesselBasis(torch.nn.Module):
@@ -55,6 +55,25 @@ def __repr__(self):
             f"trainable={self.bessel_weights.requires_grad})"
         )
 
+def continuous_sinous_embedding(densities, dim, max_density=100):
+    """
+    Create sinusoidal density embeddings.
+
+    :param densities: a 1-D Tensor of N per-node local densities, one per batch element.
+        These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_density: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_density) * torch.arange(start=0, end=half) / half + ).to(device=densities.device) + args = densities[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding @compile_mode("script") class ChebychevBasis(torch.nn.Module): diff --git a/mace/modules/symmetric_contraction.py b/mace/modules/symmetric_contraction.py index 7c48ed0e..950dc212 100644 --- a/mace/modules/symmetric_contraction.py +++ b/mace/modules/symmetric_contraction.py @@ -34,13 +34,13 @@ def __init__( internal_weights: Optional[bool] = None, shared_weights: Optional[bool] = None, num_elements: Optional[int] = None, + agnostic: Optional[bool] = False, tensor_format: str = "symmetric_cp", flexible_feats_L: bool = False, gaussian_prior: bool = False, ) -> None: super().__init__() - self.gaussian_prior = gaussian_prior if irrep_normalization is None: @@ -349,24 +349,29 @@ def __init__( self.weights = weights[:-1] self.weights_max = weights[-1] - # add gaussian prior - if gaussian_prior: - # given by exp(-wK * ( Rk_fn(Rk) )^2 - wL * sum_t l^2_t) - # For CP, this can be done by two steps: - # - multiply A_klm alog lm channel by exp(-wL * l^2_t) during the t th contraction - # - multiply the final feature along k channel by exp(-wK * ( Rk_fn(Rk) )^2 ) - # normalization weighting over the lms - self.wL = torch.nn.Parameter(torch.tensor(1.5, requires_grad = True)) - # normalization weighting over radials - #self.wK = torch.nn.Parameter(torch.tensor(1.5, requires_grad = True)) - # # learnable smoothness of learned radial - # self.Rk_fn = nn.FullyConnectedNet( - # [self.num_features] - # + [ - # self.num_features, - # ], - # torch.nn.functional.silu, - # ) + # # add gaussian prior + # if gaussian_prior: + # # given by exp(-wK * ( Rk_fn(Rk) )^2 - wL * sum_t l^2_t) + # # For CP, this can be done by two steps: + # # - multiply A_klm alog lm channel by exp(-wL * l^2_t) during the t th contraction + # # - multiply the final feature along k channel by exp(-wK * ( Rk_fn(Rk) )^2 ) + # # normalization weighting over the lms + # self.wL = torch.nn.Parameter(torch.tensor(1.5, requires_grad = True)) + # # smoothness prior for normalization + # self.register_buffer("l2vec", torch.tensor([irrep.l for _, irrep in self.coupling_irreps for _ in range((2 * irrep.l + 1))])) + # # normalization weighting over radials + # #self.wK = torch.nn.Parameter(torch.tensor(1.5, requires_grad = True)) + # # # learnable smoothness of learned radial + # # self.Rk_fn = nn.FullyConnectedNet( + # # [self.num_features] + # # + [ + # # self.num_features, + # # ], + # # torch.nn.functional.silu, + # # ) + # else: + # self.wL = 0 + # self.register_buffer("l2vec", torch.zeros_like(torch.tensor([irrep.l for _, irrep in self.coupling_irreps for _ in range((2 * irrep.l + 1))]))) def forward(self, x: torch.Tensor, y: torch.Tensor): irrep_out = self.irrep_out @@ -505,13 +510,15 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): out = self.graph_opt_main( self.U_tensors(self.correlation), self.weights_max, # shape = (num_elements, num_paras, channel) - x, # [nnodes, channel, num_paras] + #x * torch.exp(-self.wL * (self.l2vec ** 2) ) if self.gaussian_prior else x, # [nnodes, channel, num_paras] + x, y, ) elif self.tensor_format == "non_symmetric_cp": out = self.graph_opt_main( self.U_tensors(self.correlation), self.weights_max, # shape = (num_elements, num_paras, 
channel) + #x[:, :, :, self.correlation - 1] * torch.exp(-self.wL * (self.l2vec ** 2) ) if self.gaussian_prior else x[:, :, :, self.correlation - 1], x[:, :, :, self.correlation - 1], y, ) @@ -525,8 +532,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): ) c_tensor = c_tensor + out if self.tensor_format == "symmetric_cp": + #out = contract_features(c_tensor, x * torch.exp(-self.wL * (self.l2vec ** 2) if self.gaussian_prior else x)) out = contract_features(c_tensor, x) elif self.tensor_format == "non_symmetric_cp": + #out = contract_features(c_tensor, x[:, :, :, self.correlation - i - 1] * torch.exp(-self.wL * (self.l2vec ** 2) if self.gaussian_prior else x[:, :, :, self.correlation - i - 1] )) out = contract_features(c_tensor, x[:, :, :, self.correlation - i - 1]) return out.reshape(out.shape[0], -1) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index bf4351a5..9bf348ec 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -155,6 +155,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "RealAgnosticInteractionBlock", "RealAgnosticDensityResidualInteractionBlock", "RealAgnosticDensityInteractionBlock" + "RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock", ], ) parser.add_argument( @@ -166,7 +167,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "RealAgnosticResidualInteractionBlock", "RealAgnosticInteractionBlock", "RealAgnosticDensityResidualInteractionBlock", - "RealAgnosticDensityInteractionBlock" + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock" ], ) parser.add_argument( @@ -249,17 +251,19 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--tensor_format", - help="Tensor format being used", - type=str, - default="symmetric_cp", - choices=["symmetric_cp", - "non_symmetric_cp", - "symmetric_tucker", - "non_symmetric_tucker", - "flexible_symmetric_tucker", - "flexible_non_symmetric_tucker" - ] - ) + help="Tensor format(s) being used for each layer", + nargs="+", + default=["symmetric_cp", "symmetric_cp"], + # choices=[ + # "symmetric_cp", + # "non_symmetric_cp", + # "symmetric_tucker", + # "non_symmetric_tucker", + # "flexible_symmetric_tucker", + # "flexible_non_symmetric_tucker" + # ], + ) + parser.add_argument( "--flexible_feats_L", help="Allowing different number of channels for different L", diff --git a/mace/tools/arg_parser_tools.py b/mace/tools/arg_parser_tools.py index 3116d17a..15ee5ca1 100644 --- a/mace/tools/arg_parser_tools.py +++ b/mace/tools/arg_parser_tools.py @@ -61,12 +61,12 @@ def check_args(args): .sort() .irreps.simplify() ) - if args.tensor_format in ["symmetric_cp", "symmetric_tucker", "non_symmetric_cp", "non_symmetric_tucker"] or args.flexible_feats_L: + if any([tf in ["symmetric_cp", "symmetric_tucker", "non_symmetric_cp", "non_symmetric_tucker"] for tf in args.tensor_format]) or args.flexible_feats_L: assert ( len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" elif args.hidden_irreps is not None: - if args.tensor_format in ["symmetric_cp", "symmetric_tucker", "non_symmetric_cp", "non_symmetric_tucker"] or args.flexible_feats_L: + if any([tf in ["symmetric_cp", "symmetric_tucker", "non_symmetric_cp", "non_symmetric_tucker"] for tf in args.tensor_format]) or args.flexible_feats_L: assert ( len({irrep.mul for irrep in 
o3.Irreps(args.hidden_irreps)}) == 1
             ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
@@ -92,6 +92,9 @@ def check_args(args):
             .irreps.simplify()
         )
 
+    # tensor format checks
+    assert args.num_interactions == len(args.tensor_format), "Number of tensor formats defined must equal the number of interactions"
+
     # Loss and optimization
     # Check Stage Two loss start
     if args.swa:
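
For reviewers, a minimal self-contained sketch of the density-embedding path introduced by this patch. It mirrors continuous_sinous_embedding from mace/modules/radial.py and the scalar-channel injection in the new interaction block; the toy shapes, the density_linear layer, and the random inputs below are illustrative stand-ins, not the model's actual configuration:

import math
import torch

def sinusoidal_density_embedding(densities: torch.Tensor, dim: int = 32, max_density: float = 100.0) -> torch.Tensor:
    """Fixed-frequency sinusoidal embedding of per-node local densities
    (same construction as continuous_sinous_embedding in the patch)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_density) * torch.arange(half) / half).to(densities.device)
    args = densities[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad with a zero column if dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

# Toy usage: inject the embedding into the scalar (l=0) channels of a pooled message.
num_nodes, num_channels, dim = 5, 16, 32
node_local_density = torch.rand(num_nodes) * 10          # stand-in for the scatter_sum of per-edge densities
density_linear = torch.nn.Linear(dim, num_channels, bias=False)
message_scalars = torch.randn(num_nodes, num_channels)   # stand-in for message[:, irreps_out.slices()[0]]
message_scalars = message_scalars + density_linear(sinusoidal_density_embedding(node_local_density, dim))
print(message_scalars.shape)  # torch.Size([5, 16])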
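
Similarly, a standalone sketch (not mace's actual build_default_arg_parser) of how the reworked --tensor_format option accepts one format per interaction layer via nargs="+", together with the per-layer length check added to check_args:

import argparse

# Illustrative stand-in for the changed option in mace/tools/arg_parser.py.
parser = argparse.ArgumentParser()
parser.add_argument("--num_interactions", type=int, default=2)
parser.add_argument(
    "--tensor_format",
    help="Tensor format(s) being used for each layer",
    nargs="+",
    default=["symmetric_cp", "symmetric_cp"],
)
args = parser.parse_args(
    ["--num_interactions", "2", "--tensor_format", "symmetric_cp", "non_symmetric_tucker"]
)

# Mirrors the new check in arg_parser_tools.check_args: one format per interaction.
assert args.num_interactions == len(args.tensor_format), (
    "Number of tensor formats defined must equal the number of interactions"
)
print(args.tensor_format)  # ['symmetric_cp', 'non_symmetric_tucker']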