From 38377f8911264f95c7ebbb4bf8f0442c36ad793f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 19 Sep 2024 08:58:41 +0100 Subject: [PATCH 1/3] test density normalization --- mace/modules/__init__.py | 4 + mace/modules/blocks.py | 192 +++++++++++++++++++++++++++++++++++++++ mace/tools/arg_parser.py | 4 + 3 files changed, 200 insertions(+) diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 9278130f..69e102b5 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -18,6 +18,8 @@ RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, ScaleShiftBlock, ) from .loss import ( @@ -56,6 +58,8 @@ "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, } scaling_classes: Dict[str, Callable] = { diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 34539b0b..2bd20f74 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -664,6 +664,198 @@ def forward( ) # [n_nodes, channels, (lmax + 1)**2] +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + num_scalar_node_features = self.node_feats_irreps[0].mul + self.node_scalar_linear = torch.nn.Linear( + num_scalar_node_features, self.conv_tp.weight_numel + ) + + self.reshape = reshape_irreps(self.irreps_out) + self.density_fn = nn.FullyConnectedNet( + [self.conv_tp.weight_numel] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + node_feats_scalar = self.node_scalar_linear( + node_feats[:, self.node_feats_irreps.slices()[0]] + ) + edge_density = torch.tanh( + self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 + ) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + num_scalar_node_features = self.node_feats_irreps[0].mul + self.node_scalar_linear = torch.nn.Linear( + num_scalar_node_features, self.conv_tp.weight_numel + ) + + self.reshape = reshape_irreps(self.irreps_out) + self.density_fn = nn.FullyConnectedNet( + [self.conv_tp.weight_numel] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + node_feats_scalar = self.node_scalar_linear( + node_feats[:, self.node_feats_irreps.slices()[0]] + ) + edge_density = torch.tanh( + self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 + ) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 046f04d6..8fa8c0ac 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -153,6 +153,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "RealAgnosticResidualInteractionBlock", "RealAgnosticAttResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( @@ -163,6 +165,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: choices=[ "RealAgnosticResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( From 294f90cfced3e9a518c1584ba867cc0ceb018cb3 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 7 Oct 2024 10:23:56 +0100 Subject: [PATCH 2/3] simplify the density normalization --- mace/modules/blocks.py | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 2bd20f74..0db3b02e 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -710,19 +710,15 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Density normalization - num_scalar_node_features = self.node_feats_irreps[0].mul - self.node_scalar_linear = torch.nn.Linear( - num_scalar_node_features, self.conv_tp.weight_numel - ) - - self.reshape = reshape_irreps(self.irreps_out) self.density_fn = nn.FullyConnectedNet( - [self.conv_tp.weight_numel] + [input_dim] + [ 1, ], torch.nn.functional.silu, ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) def forward( self, @@ -737,12 +733,7 @@ def forward( num_nodes = node_feats.shape[0] node_feats = self.linear_up(node_feats) tp_weights = self.conv_tp_weights(edge_feats) - node_feats_scalar = self.node_scalar_linear( - node_feats[:, self.node_feats_irreps.slices()[0]] - ) - edge_density = torch.tanh( - self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 - ) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) mji = self.conv_tp( node_feats[sender], edge_attrs, tp_weights ) # [n_edges, irreps] @@ -806,20 +797,17 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Density normalization - num_scalar_node_features = self.node_feats_irreps[0].mul - self.node_scalar_linear = torch.nn.Linear( - num_scalar_node_features, self.conv_tp.weight_numel - ) - - self.reshape = reshape_irreps(self.irreps_out) self.density_fn = nn.FullyConnectedNet( - [self.conv_tp.weight_numel] + [input_dim] + [ 1, ], torch.nn.functional.silu, ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + def forward( self, node_attrs: torch.Tensor, @@ -834,12 +822,7 @@ def forward( sc = self.skip_tp(node_feats, node_attrs) node_feats = self.linear_up(node_feats) tp_weights = self.conv_tp_weights(edge_feats) - node_feats_scalar = self.node_scalar_linear( - node_feats[:, self.node_feats_irreps.slices()[0]] - ) - edge_density = torch.tanh( - self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2 - ) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) mji = self.conv_tp( node_feats[sender], edge_attrs, tp_weights ) # [n_edges, irreps] From 165993c4db404644e574575fc69490f56fad473b Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:53:23 +0000 Subject: [PATCH 3/3] fixing import order --- mace/modules/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 69e102b5..e48e0b23 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -15,11 +15,11 @@ NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, - RealAgnosticDensityResidualInteractionBlock, - RealAgnosticDensityInteractionBlock, ScaleShiftBlock, ) from .loss import (