Skip to content

Commit

Permalink
Added two model options, see also commit 4b45a9d
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed Dec 5, 2024
1 parent dc650a3 commit 4f89aa8
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions mlcolvar/core/nn/graph/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class SchNetModel(BaseGNN):
Number of filters.
n_hidden_channels: int
Size of hidden embeddings.
aggr: str
Type of the GNN aggr function.
w_out_after_sum: bool
If apply the readout MLP layer after the scatter sum.
References
----------
.. [1] Schütt, Kristof T., et al. "Schnet–a deep learning architecture for
Expand All @@ -53,7 +57,9 @@ def __init__(
n_layers: int = 2,
n_filters: int = 16,
n_hidden_channels: int = 16,
drop_rate: int = 0
drop_rate: int = 0,
aggr: str = 'mean',
w_out_after_sum: bool = False
) -> None:

super().__init__(
Expand All @@ -66,7 +72,7 @@ def __init__(

self.layers = nn.ModuleList([
InteractionBlock(
n_hidden_channels, n_bases, n_filters, cutoff
n_hidden_channels, n_bases, n_filters, cutoff, aggr
)

Check failure

Code scanning / CodeQL

Wrong number of arguments in a class instantiation Error

Call to
InteractionBlock.__init__
with too many arguments; should be no more than 4.
for _ in range(n_layers)
])
Expand All @@ -77,6 +83,8 @@ def __init__(
nn.Linear(n_hidden_channels // 2, n_out)
])

self._w_out_after_sum = w_out_after_sum

self.reset_parameters()

def reset_parameters(self) -> None:
Expand Down Expand Up @@ -115,8 +123,9 @@ def forward(
for layer in self.layers:
h_V = h_V + layer(h_V, data['edge_index'], h_E[0], h_E[1])

for w in self.W_out:
h_V = w(h_V)
if not self._w_out_after_sum:
for w in self.W_out:
h_V = w(h_V)
out = h_V

if scatter_mean:
Expand All @@ -128,6 +137,10 @@ def forward(
# TODO check this is equivalent in torch scatter
out = torch_scatter.scatter_sum(out, batch_id, dim=0)
out = out / data['n_system']

if self._w_out_after_sum:
for w in self.W_out:
out = w(out)

return out

Expand Down

0 comments on commit 4f89aa8

Please sign in to comment.