From 4f89aa8f8f7124d20e5e43bf4f20f188bead6340 Mon Sep 17 00:00:00 2001
From: EnricoTrizio
Date: Thu, 5 Dec 2024 13:12:37 +0100
Subject: [PATCH] Added two model options, see also commit
 4b45a9ded1f7da46cc1a0eac091afee340230929

---
 mlcolvar/core/nn/graph/schnet.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlcolvar/core/nn/graph/schnet.py b/mlcolvar/core/nn/graph/schnet.py
index 4a08cbc..f34f363 100644
--- a/mlcolvar/core/nn/graph/schnet.py
+++ b/mlcolvar/core/nn/graph/schnet.py
@@ -37,6 +37,10 @@ class SchNetModel(BaseGNN):
         Number of filters.
     n_hidden_channels: int
         Size of hidden embeddings.
+    aggr: str
+        Type of the GNN aggregation function.
+    w_out_after_sum: bool
+        Whether to apply the readout MLP layers after the scatter sum.
     References
     ----------
     .. [1] Schütt, Kristof T., et al. "Schnet–a deep learning architecture for
@@ -53,7 +57,9 @@ def __init__(
         n_layers: int = 2,
         n_filters: int = 16,
         n_hidden_channels: int = 16,
-        drop_rate: int = 0
+        drop_rate: int = 0,
+        aggr: str = 'mean',
+        w_out_after_sum: bool = False
     ) -> None:
 
         super().__init__(
@@ -66,7 +72,7 @@ def __init__(
 
         self.layers = nn.ModuleList([
             InteractionBlock(
-                n_hidden_channels, n_bases, n_filters, cutoff
+                n_hidden_channels, n_bases, n_filters, cutoff, aggr
             ) for _ in range(n_layers)
         ])
 
@@ -77,6 +83,8 @@ def __init__(
             nn.Linear(n_hidden_channels // 2, n_out)
         ])
 
+        self._w_out_after_sum = w_out_after_sum
+
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -115,8 +123,9 @@ def forward(
         for layer in self.layers:
             h_V = h_V + layer(h_V, data['edge_index'], h_E[0], h_E[1])
 
-        for w in self.W_out:
-            h_V = w(h_V)
+        if not self._w_out_after_sum:
+            for w in self.W_out:
+                h_V = w(h_V)
         out = h_V
 
         if scatter_mean:
@@ -128,6 +137,10 @@ def forward(
                 # TODO check this is equivalent in torch scatter
                 out = torch_scatter.scatter_sum(out, batch_id, dim=0)
                 out = out / data['n_system']
+
+        if self._w_out_after_sum:
+            for w in self.W_out:
+                out = w(out)
 
         return out
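
For reference, a minimal usage sketch of the two constructor options introduced by this patch. Only `aggr` and `w_out_after_sum` come from the diff itself; the required arguments shown below and the 'add' aggregation value are illustrative assumptions, not part of the patch.

# Minimal sketch, assuming SchNetModel is importable from this module and
# accepts the keyword arguments added by the patch. The arguments n_out,
# cutoff, and atomic_numbers are hypothetical placeholders for the required
# constructor arguments not visible in the diff context.
from mlcolvar.core.nn.graph.schnet import SchNetModel

model = SchNetModel(
    n_out=1,                # hypothetical required argument
    cutoff=5.0,             # hypothetical required argument
    atomic_numbers=[1, 8],  # hypothetical required argument
    aggr='add',             # forwarded to every InteractionBlock (default: 'mean')
    w_out_after_sum=True,   # apply the W_out MLP after the scatter sum, not per node
)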