Skip to content

Commit

Permalink
simplify the density normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Oct 7, 2024
1 parent 38377f8 commit 294f90c
Showing 1 changed file with 9 additions and 26 deletions.
35 changes: 9 additions & 26 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,19 +710,15 @@ def _setup(self) -> None:
self.reshape = reshape_irreps(self.irreps_out)

# Density normalization
num_scalar_node_features = self.node_feats_irreps[0].mul
self.node_scalar_linear = torch.nn.Linear(
num_scalar_node_features, self.conv_tp.weight_numel
)

self.reshape = reshape_irreps(self.irreps_out)
self.density_fn = nn.FullyConnectedNet(
[self.conv_tp.weight_numel]
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out)

def forward(
self,
Expand All @@ -737,12 +733,7 @@ def forward(
num_nodes = node_feats.shape[0]
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
node_feats_scalar = self.node_scalar_linear(
node_feats[:, self.node_feats_irreps.slices()[0]]
)
edge_density = torch.tanh(
self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2
)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
Expand Down Expand Up @@ -806,20 +797,17 @@ def _setup(self) -> None:
self.reshape = reshape_irreps(self.irreps_out)

# Density normalization
num_scalar_node_features = self.node_feats_irreps[0].mul
self.node_scalar_linear = torch.nn.Linear(
num_scalar_node_features, self.conv_tp.weight_numel
)

self.reshape = reshape_irreps(self.irreps_out)
self.density_fn = nn.FullyConnectedNet(
[self.conv_tp.weight_numel]
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)

# Reshape
self.reshape = reshape_irreps(self.irreps_out)

def forward(
self,
node_attrs: torch.Tensor,
Expand All @@ -834,12 +822,7 @@ def forward(
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
node_feats_scalar = self.node_scalar_linear(
node_feats[:, self.node_feats_irreps.slices()[0]]
)
edge_density = torch.tanh(
self.density_fn(tp_weights * node_feats_scalar[sender]) ** 2
)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
Expand Down

0 comments on commit 294f90c

Please sign in to comment.