From 9314aabd1f284b7fc5549d9a3e2801acd4c2e6c6 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 27 Aug 2019 22:29:25 +0800 Subject: [PATCH] [Refactor] Interface of nn modules (#798) * refactor * upd mpnn --- examples/mxnet/gcn/gcn.py | 2 +- examples/pytorch/appnp/appnp.py | 2 +- examples/pytorch/gat/gat.py | 4 +- examples/pytorch/gcn/gcn.py | 2 +- examples/pytorch/gin/gin.py | 2 +- examples/pytorch/graphsage/graphsage.py | 2 +- .../model_zoo/citation_network/conf.py | 2 +- .../model_zoo/citation_network/models.py | 20 +-- examples/pytorch/sgc/sgc.py | 4 +- examples/pytorch/sgc/sgc_reddit.py | 4 +- examples/pytorch/tagcn/tagcn.py | 2 +- python/dgl/model_zoo/chem/mpnn.py | 2 +- python/dgl/nn/mxnet/conv.py | 6 +- python/dgl/nn/mxnet/glob.py | 36 +++--- python/dgl/nn/pytorch/conv.py | 116 ++++++++++-------- python/dgl/nn/pytorch/glob.py | 48 ++++---- tests/mxnet/test_nn.py | 38 +++--- tests/pytorch/test_nn.py | 86 ++++++------- 18 files changed, 194 insertions(+), 184 deletions(-) diff --git a/examples/mxnet/gcn/gcn.py b/examples/mxnet/gcn/gcn.py index 2e1de3114046..bf60bf090303 100644 --- a/examples/mxnet/gcn/gcn.py +++ b/examples/mxnet/gcn/gcn.py @@ -36,5 +36,5 @@ def forward(self, features): for i, layer in enumerate(self.layers): if i != 0: h = self.dropout(h) - h = layer(h, self.g) + h = layer(self.g, h) return h diff --git a/examples/pytorch/appnp/appnp.py b/examples/pytorch/appnp/appnp.py index 5a055789255a..c7ee38116f9d 100644 --- a/examples/pytorch/appnp/appnp.py +++ b/examples/pytorch/appnp/appnp.py @@ -53,5 +53,5 @@ def forward(self, features): h = self.activation(layer(h)) h = self.layers[-1](self.feat_drop(h)) # propagation step - h = self.propagate(h, self.g) + h = self.propagate(self.g, h) return h diff --git a/examples/pytorch/gat/gat.py b/examples/pytorch/gat/gat.py index 2103961bef24..88a06c56880f 100644 --- a/examples/pytorch/gat/gat.py +++ b/examples/pytorch/gat/gat.py @@ -49,7 +49,7 @@ def __init__(self, def forward(self, inputs): h = inputs for l in range(self.num_layers): - h = self.gat_layers[l](h, self.g).flatten(1) + h = self.gat_layers[l](self.g, h).flatten(1) # output projection - logits = self.gat_layers[-1](h, self.g).mean(1) + logits = self.gat_layers[-1](self.g, h).mean(1) return logits diff --git a/examples/pytorch/gcn/gcn.py b/examples/pytorch/gcn/gcn.py index 1b142fc20b7e..a4c0e1609449 100644 --- a/examples/pytorch/gcn/gcn.py +++ b/examples/pytorch/gcn/gcn.py @@ -35,5 +35,5 @@ def forward(self, features): for i, layer in enumerate(self.layers): if i != 0: h = self.dropout(h) - h = layer(h, self.g) + h = layer(self.g, h) return h diff --git a/examples/pytorch/gin/gin.py b/examples/pytorch/gin/gin.py index 1260020675de..b73e7146a46f 100644 --- a/examples/pytorch/gin/gin.py +++ b/examples/pytorch/gin/gin.py @@ -155,7 +155,7 @@ def forward(self, g): hidden_rep = [h] for layer in range(self.num_layers - 1): - h = self.ginlayers[layer](h, g) + h = self.ginlayers[layer](g, h) hidden_rep.append(h) score_over_layer = 0 diff --git a/examples/pytorch/graphsage/graphsage.py b/examples/pytorch/graphsage/graphsage.py index dbaabeb35dba..1a4a22eea78b 100644 --- a/examples/pytorch/graphsage/graphsage.py +++ b/examples/pytorch/graphsage/graphsage.py @@ -41,7 +41,7 @@ def __init__(self, def forward(self, features): h = features for layer in self.layers: - h = layer(h, self.g) + h = layer(self.g, h) return h diff --git a/examples/pytorch/model_zoo/citation_network/conf.py b/examples/pytorch/model_zoo/citation_network/conf.py index 57297b86d419..05f53c2ed58f 100644 --- a/examples/pytorch/model_zoo/citation_network/conf.py +++ b/examples/pytorch/model_zoo/citation_network/conf.py @@ -50,7 +50,7 @@ } CHEBNET_CONFIG = { - 'extra_args': [16, 1, 3, True], + 'extra_args': [32, 1, 2, True], 'lr': 1e-2, 'weight_decay': 5e-4, } diff --git a/examples/pytorch/model_zoo/citation_network/models.py b/examples/pytorch/model_zoo/citation_network/models.py index 1ec7de10a09e..176b61a6602f 100644 --- a/examples/pytorch/model_zoo/citation_network/models.py +++ b/examples/pytorch/model_zoo/citation_network/models.py @@ -30,7 +30,7 @@ def forward(self, features): for i, layer in enumerate(self.layers): if i != 0: h = self.dropout(h) - h = layer(h, self.g) + h = layer(self.g, h) return h @@ -70,9 +70,9 @@ def __init__(self, def forward(self, inputs): h = inputs for l in range(self.num_layers): - h = self.gat_layers[l](h, self.g).flatten(1) + h = self.gat_layers[l](self.g, h).flatten(1) # output projection - logits = self.gat_layers[-1](h, self.g).mean(1) + logits = self.gat_layers[-1](self.g, h).mean(1) return logits @@ -101,7 +101,7 @@ def __init__(self, def forward(self, features): h = features for layer in self.layers: - h = layer(h, self.g) + h = layer(self.g, h) return h @@ -148,7 +148,7 @@ def forward(self, features): h = self.activation(layer(h)) h = self.layers[-1](self.feat_drop(h)) # propagation step - h = self.propagate(h, self.g) + h = self.propagate(self.g, h) return h @@ -178,7 +178,7 @@ def forward(self, features): for i, layer in enumerate(self.layers): if i != 0: h = self.dropout(h) - h = layer(h, self.g) + h = layer(self.g, h) return h @@ -210,7 +210,7 @@ def __init__(self, def forward(self, features): h = self.proj(features) for layer in self.layers: - h = layer(h, self.g) + h = layer(self.g, h) return self.cls(h) @@ -231,7 +231,7 @@ def __init__(self, bias=bias) def forward(self, features): - return self.net(features, self.g) + return self.net(self.g, features) class GIN(nn.Module): @@ -286,7 +286,7 @@ def __init__(self, def forward(self, features): h = features for layer in self.layers: - h = layer(h, self.g) + h = layer(self.g, h) return h class ChebNet(nn.Module): @@ -316,5 +316,5 @@ def __init__(self, def forward(self, features): h = features for layer in self.layers: - h = layer(h, self.g) + h = layer(self.g, h, [2]) return h \ No newline at end of file diff --git a/examples/pytorch/sgc/sgc.py b/examples/pytorch/sgc/sgc.py index c793e32bb9d5..f799269fe588 100644 --- a/examples/pytorch/sgc/sgc.py +++ b/examples/pytorch/sgc/sgc.py @@ -19,7 +19,7 @@ def evaluate(model, g, features, labels, mask): model.eval() with torch.no_grad(): - logits = model(features, g)[mask] # only compute the evaluation set + logits = model(g, features)[mask] # only compute the evaluation set labels = labels[mask] _, indices = torch.max(logits, dim=1) correct = torch.sum(indices == labels) @@ -86,7 +86,7 @@ def main(args): if epoch >= 3: t0 = time.time() # forward - logits = model(features, g) # only compute the train set + logits = model(g, features) # only compute the train set loss = loss_fcn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() diff --git a/examples/pytorch/sgc/sgc_reddit.py b/examples/pytorch/sgc/sgc_reddit.py index 2f458e6a9b03..f1a25e438d48 100644 --- a/examples/pytorch/sgc/sgc_reddit.py +++ b/examples/pytorch/sgc/sgc_reddit.py @@ -21,7 +21,7 @@ def normalize(h): def evaluate(model, features, graph, labels, mask): model.eval() with torch.no_grad(): - logits = model(features, graph)[mask] # only compute the evaluation set + logits = model(graph, features)[mask] # only compute the evaluation set labels = labels[mask] _, indices = torch.max(logits, dim=1) correct = torch.sum(indices == labels) @@ -82,7 +82,7 @@ def main(args): # define loss closure def closure(): optimizer.zero_grad() - output = model(features, g)[train_mask] + output = model(g, features)[train_mask] loss_train = F.cross_entropy(output, labels[train_mask]) loss_train.backward() return loss_train diff --git a/examples/pytorch/tagcn/tagcn.py b/examples/pytorch/tagcn/tagcn.py index 2bfcff77a251..804b91daee1e 100644 --- a/examples/pytorch/tagcn/tagcn.py +++ b/examples/pytorch/tagcn/tagcn.py @@ -35,5 +35,5 @@ def forward(self, features): for i, layer in enumerate(self.layers): if i != 0: h = self.dropout(h) - h = layer(h, self.g) + h = layer(self.g, h) return h diff --git a/python/dgl/model_zoo/chem/mpnn.py b/python/dgl/model_zoo/chem/mpnn.py index a650b339b916..3d64baf1c76d 100644 --- a/python/dgl/model_zoo/chem/mpnn.py +++ b/python/dgl/model_zoo/chem/mpnn.py @@ -145,7 +145,7 @@ def forward(self, g): out, h = self.gru(m.unsqueeze(0), h) out = out.squeeze(0) - out = self.set2set(out, g) + out = self.set2set(g, out) out = F.relu(self.lin1(out)) out = self.lin2(out) return out diff --git a/python/dgl/nn/mxnet/conv.py b/python/dgl/nn/mxnet/conv.py index 9acca0ecc309..f792af5331f5 100644 --- a/python/dgl/nn/mxnet/conv.py +++ b/python/dgl/nn/mxnet/conv.py @@ -83,7 +83,7 @@ def __init__(self, self._activation = activation - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute graph convolution. Notes @@ -95,10 +95,10 @@ def forward(self, feat, graph): Parameters ---------- - feat : mxnet.NDArray - The input feature graph : DGLGraph The graph. + feat : mxnet.NDArray + The input feature Returns ------- diff --git a/python/dgl/nn/mxnet/glob.py b/python/dgl/nn/mxnet/glob.py index f3b89e775e03..37053c8889e4 100644 --- a/python/dgl/nn/mxnet/glob.py +++ b/python/dgl/nn/mxnet/glob.py @@ -19,16 +19,16 @@ class SumPooling(nn.Block): def __init__(self): super(SumPooling, self).__init__() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute sum pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : mxnet.NDArray The input feature with shape :math:`(N, *)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -56,16 +56,16 @@ class AvgPooling(nn.Block): def __init__(self): super(AvgPooling, self).__init__() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute average pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : mxnet.NDArray The input feature with shape :math:`(N, *)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -93,16 +93,16 @@ class MaxPooling(nn.Block): def __init__(self): super(MaxPooling, self).__init__() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute max pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : mxnet.NDArray The input feature with shape :math:`(N, *)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -134,16 +134,16 @@ def __init__(self, k): super(SortPooling, self).__init__() self.k = k - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute sort pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : mxnet.NDArray The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -190,16 +190,16 @@ def __init__(self, gate_nn, feat_nn=None): self.gate_nn = gate_nn self.feat_nn = feat_nn - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute global attention pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : mxnet.NDArray The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -258,16 +258,16 @@ def __init__(self, input_dim, n_iters, n_layers): self.lstm = gluon.rnn.LSTM( self.input_dim, num_layers=n_layers, input_size=self.output_dim) - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute set2set pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : mxnet.NDArray The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- diff --git a/python/dgl/nn/pytorch/conv.py b/python/dgl/nn/pytorch/conv.py index 3ff769b59f9a..825ad0e7a1a2 100644 --- a/python/dgl/nn/pytorch/conv.py +++ b/python/dgl/nn/pytorch/conv.py @@ -107,7 +107,7 @@ def reset_parameters(self): if self.bias is not None: init.zeros_(self.bias) - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute graph convolution. Notes @@ -119,10 +119,10 @@ def forward(self, feat, graph): Parameters ---------- - feat : torch.Tensor - The input feature graph : DGLGraph The graph. + feat : torch.Tensor + The input feature Returns ------- @@ -246,16 +246,16 @@ def reset_parameters(self): if isinstance(self.res_fc, nn.Linear): nn.init.xavier_normal_(self.res_fc.weight, gain=gain) - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute graph attention network layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. - graph : DGLGraph - The graph. Returns ------- @@ -338,16 +338,16 @@ def reset_parameters(self): gain = nn.init.calculate_gain('relu') nn.init.xavier_normal_(self.lin.weight, gain=gain) - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute topology adaptive graph convolution. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. - graph : DGLGraph - The graph. Returns ------- @@ -643,16 +643,16 @@ def _lstm_reducer(self, nodes): _, (rst, _) = self.lstm(m, h) return {'neigh': rst.squeeze(0)} - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute GraphSAGE layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. - graph : DGLGraph - The graph. Returns ------- @@ -742,11 +742,13 @@ def reset_parameters(self): self.gru.reset_parameters() init.xavier_normal_(self.edge_embed.weight, gain=gain) - def forward(self, feat, etypes, graph): + def forward(self, graph, feat, etypes): """Compute Gated Graph Convolution layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`N` is the number of nodes of the graph and :math:`D_{in}` is the @@ -754,8 +756,6 @@ def forward(self, feat, etypes, graph): etypes : torch.LongTensor The edge type tensor of shape :math:`(E,)` where :math:`E` is the number of edges of the graph. - graph : DGLGraph - The graph. Returns ------- @@ -856,11 +856,13 @@ def reset_parameters(self): if self.bias is not None: init.zeros_(self.bias.data) - def forward(self, feat, pseudo, graph): + def forward(self, graph, feat, pseudo): """Compute Gaussian Mixture Model Convolution layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`N` is the number of nodes of the graph and :math:`D_{in}` is the @@ -869,8 +871,6 @@ def forward(self, feat, pseudo, graph): The pseudo coordinate tensor of shape :math:`(E, D_{u})` where :math:`E` is the number of edges of the graph and :math:`D_{u}` is the dimensionality of pseudo coordinate. - graph : DGLGraph - The graph. Returns ------- @@ -940,18 +940,18 @@ def __init__(self, else: self.register_buffer('eps', th.FloatTensor([init_eps])) - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute Graph Isomorphism Network layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D)` where :math:`D` could be any positive integer, :math:`N` is the number of nodes. If ``apply_func`` is not None, :math:`D` should fit the input dimensionality requirement of ``apply_func``. - graph : DGLGraph - The graph. Returns ------- @@ -1025,16 +1025,22 @@ def reset_parameters(self): if module.bias is not None: init.zeros_(module.bias) - def forward(self, feat, graph, lambda_max=None): + def forward(self, graph, feat, lambda_max=None): r"""Compute ChebNet layer. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. - graph : DGLGraph - The graph. + lambda_max : list or tensor or None, optional. + A list(tensor) with length :math:`B`, stores the largest eigenvalue + of the normalized laplacian of each individual graph in ``graph``, + where :math:`B` is the batch size of the input graph. Default: None. + If None, this method would compute the list by calling + ``dgl.laplacian_lambda_max``. Returns ------- @@ -1047,13 +1053,13 @@ def forward(self, feat, graph, lambda_max=None): graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device) if lambda_max is None: lambda_max = laplacian_lambda_max(graph) - lambda_max = th.Tensor(lambda_max).to(feat.device) + if isinstance(lambda_max, list): + lambda_max = th.Tensor(lambda_max).to(feat.device) if lambda_max.dim() < 1: lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1) # broadcast from (B, 1) to (N, 1) lambda_max = broadcast_nodes(graph, lambda_max) # T0(X) - Tx_0 = feat rst = self.fc[0](Tx_0) # T1(X) @@ -1125,16 +1131,16 @@ def __init__(self, self._k = k self.norm = norm - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute Simplifying Graph Convolution layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. - graph : DGLGraph - The graph. Returns ------- @@ -1241,11 +1247,13 @@ def reset_parameters(self): if isinstance(self.res_fc, nn.Linear): nn.init.xavier_normal_(self.res_fc.weight, gain=gain) - def forward(self, feat, efeat, graph): + def forward(self, graph, feat, efeat): r"""Compute MPNN Graph Convolution layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, D_{in})` where :math:`N` is the number of nodes of the graph and :math:`D_{in}` is the @@ -1253,8 +1261,6 @@ def forward(self, feat, efeat, graph): efeat : torch.Tensor The edge feature of shape :math:`(N, *)`, should fit the input shape requirement of ``edge_nn``. - graph : DGLGraph - The graph. Returns ------- @@ -1309,16 +1315,16 @@ def __init__(self, self._alpha = alpha self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute APPNP layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, *)` :math:`N` is the number of nodes, and :math:`*` could be of any shape. - graph : DGLGraph - The graph. Returns ------- @@ -1374,16 +1380,16 @@ def __init__(self, else: self.register_buffer('beta', th.Tensor([init_beta])) - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute AGNN layer. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature of shape :math:`(N, *)` :math:`N` is the number of nodes, and :math:`*` could be of any shape. - graph : DGLGraph - The graph. Returns ------- @@ -1452,18 +1458,18 @@ def reset_parameters(self): if self.bias is not None: init.zeros_(self.bias) - def forward(self, feat, adj): + def forward(self, adj, feat): r"""Compute (Dense) Graph Convolution layer. Parameters ---------- - feat : torch.Tensor - The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` - is size of input feature, :math:`N` is the number of nodes. adj : torch.Tensor The adjacency matrix of the graph to apply Graph Convolution on, should be of shape :math:`(N, N)`, where a row represents the destination and a column represents the source. + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. Returns ------- @@ -1549,18 +1555,18 @@ def reset_parameters(self): gain = nn.init.calculate_gain('relu') nn.init.xavier_uniform_(self.fc.weight, gain=gain) - def forward(self, feat, adj): + def forward(self, adj, feat): r"""Compute (Dense) Graph SAGE layer. Parameters ---------- - feat : torch.Tensor - The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` - is size of input feature, :math:`N` is the number of nodes. adj : torch.Tensor The adjacency matrix of the graph to apply Graph Convolution on, should be of shape :math:`(N, N)`, where a row represents the destination and a column represents the source. + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. Returns ------- @@ -1629,18 +1635,21 @@ def reset_parameters(self): for i in range(self._k): init.xavier_normal_(self.W[i], init.calculate_gain('relu')) - def forward(self, feat, adj): + def forward(self, adj, feat, lambda_max=None): r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer. Parameters ---------- - feat : torch.Tensor - The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` - is size of input feature, :math:`N` is the number of nodes. adj : torch.Tensor The adjacency matrix of the graph to apply Graph Convolution on, should be of shape :math:`(N, N)`, where a row represents the destination and a column represents the source. + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + lambda_max : float or None, optional + A float value indicates the largest eigenvalue of given graph. + Default: None. Returns ------- @@ -1656,10 +1665,11 @@ def forward(self, feat, adj): I = th.eye(num_nodes).to(A) L = I - D_invsqrt @ A @ D_invsqrt - lambda_ = th.eig(L)[0][:, 0] - lambda_max = lambda_.max() - L_hat = 2 * L / lambda_max - I + if lambda_max is None: + lambda_ = th.eig(L)[0][:, 0] + lambda_max = lambda_.max() + L_hat = 2 * L / lambda_max - I Z = [th.eye(num_nodes).to(A)] for i in range(1, self._k): if i == 1: diff --git a/python/dgl/nn/pytorch/glob.py b/python/dgl/nn/pytorch/glob.py index b925faf8f839..cb8707059155 100644 --- a/python/dgl/nn/pytorch/glob.py +++ b/python/dgl/nn/pytorch/glob.py @@ -23,17 +23,17 @@ class SumPooling(nn.Module): def __init__(self): super(SumPooling, self).__init__() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute sum pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, *)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -57,16 +57,16 @@ class AvgPooling(nn.Module): def __init__(self): super(AvgPooling, self).__init__() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute average pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, *)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -90,16 +90,16 @@ class MaxPooling(nn.Module): def __init__(self): super(MaxPooling, self).__init__() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute max pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, *)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -127,16 +127,16 @@ def __init__(self, k): super(SortPooling, self).__init__() self.k = k - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute sort pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -179,16 +179,16 @@ def __init__(self, gate_nn, feat_nn=None): self.gate_nn = gate_nn self.feat_nn = feat_nn - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute global attention pooling. Parameters ---------- + graph : DGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph - The graph. Returns ------- @@ -252,16 +252,16 @@ def reset_parameters(self): """Reinitialize learnable parameters.""" self.lstm.reset_parameters() - def forward(self, feat, graph): + def forward(self, graph, feat): r"""Compute set2set pooling. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -568,17 +568,17 @@ def __init__(self, d_model, n_heads, d_head, d_ff, self.layers = nn.ModuleList(layers) - def forward(self, feat, graph): + def forward(self, graph, feat): """ Compute the Encoder part of Set Transformer. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- @@ -634,17 +634,17 @@ def __init__(self, d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0., d self.layers = nn.ModuleList(layers) - def forward(self, feat, graph): + def forward(self, graph, feat): """ Compute the decoder part of Set Transformer. Parameters ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. feat : torch.Tensor The input feature with shape :math:`(N, D)` where :math:`N` is the number of nodes in the graph. - graph : DGLGraph or BatchedDGLGraph - The graph. Returns ------- diff --git a/tests/mxnet/test_nn.py b/tests/mxnet/test_nn.py index 6cf58132a783..bd2b088156bd 100644 --- a/tests/mxnet/test_nn.py +++ b/tests/mxnet/test_nn.py @@ -24,13 +24,13 @@ def test_graph_conv(): conv.initialize(ctx=ctx) # test#1: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias)) # test#2: more-dim h0 = F.ones((3, 5, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias)) @@ -40,12 +40,12 @@ def test_graph_conv(): # test#3: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 # test#4: basic h0 = F.ones((3, 5, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 @@ -55,18 +55,18 @@ def test_graph_conv(): with autograd.train_mode(): # test#3: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 # test#4: basic h0 = F.ones((3, 5, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 # test not override features g.ndata["h"] = 2 * F.ones((3, 1)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 1 assert len(g.edata) == 0 assert "h" in g.ndata @@ -82,13 +82,13 @@ def test_set2set(): # test#1: basic h0 = F.randn((g.number_of_nodes(), 5)) - h1 = s2s(h0, g) + h1 = s2s(g, h0) assert h1.shape[0] == 10 and h1.ndim == 1 # test#2: batched graph bg = dgl.batch([g, g, g]) h0 = F.randn((bg.number_of_nodes(), 5)) - h1 = s2s(h0, bg) + h1 = s2s(bg, h0) assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2 def test_glob_att_pool(): @@ -100,13 +100,13 @@ def test_glob_att_pool(): print(gap) # test#1: basic h0 = F.randn((g.number_of_nodes(), 5)) - h1 = gap(h0, g) + h1 = gap(g, h0) assert h1.shape[0] == 10 and h1.ndim == 1 # test#2: batched graph bg = dgl.batch([g, g, g, g]) h0 = F.randn((bg.number_of_nodes(), 5)) - h1 = gap(h0, bg) + h1 = gap(bg, h0) assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2 def test_simple_pool(): @@ -120,20 +120,20 @@ def test_simple_pool(): # test#1: basic h0 = F.randn((g.number_of_nodes(), 5)) - h1 = sum_pool(h0, g) + h1 = sum_pool(g, h0) check_close(h1, F.sum(h0, 0)) - h1 = avg_pool(h0, g) + h1 = avg_pool(g, h0) check_close(h1, F.mean(h0, 0)) - h1 = max_pool(h0, g) + h1 = max_pool(g, h0) check_close(h1, F.max(h0, 0)) - h1 = sort_pool(h0, g) + h1 = sort_pool(g, h0) assert h1.shape[0] == 10 * 5 and h1.ndim == 1 # test#2: batched graph g_ = dgl.DGLGraph(nx.path_graph(5)) bg = dgl.batch([g, g_, g, g_, g]) h0 = F.randn((bg.number_of_nodes(), 5)) - h1 = sum_pool(h0, bg) + h1 = sum_pool(bg, h0) truth = mx.nd.stack(F.sum(h0[:15], 0), F.sum(h0[15:20], 0), F.sum(h0[20:35], 0), @@ -141,7 +141,7 @@ def test_simple_pool(): F.sum(h0[40:55], 0), axis=0) check_close(h1, truth) - h1 = avg_pool(h0, bg) + h1 = avg_pool(bg, h0) truth = mx.nd.stack(F.mean(h0[:15], 0), F.mean(h0[15:20], 0), F.mean(h0[20:35], 0), @@ -149,7 +149,7 @@ def test_simple_pool(): F.mean(h0[40:55], 0), axis=0) check_close(h1, truth) - h1 = max_pool(h0, bg) + h1 = max_pool(bg, h0) truth = mx.nd.stack(F.max(h0[:15], 0), F.max(h0[15:20], 0), F.max(h0[20:35], 0), @@ -157,7 +157,7 @@ def test_simple_pool(): F.max(h0[40:55], 0), axis=0) check_close(h1, truth) - h1 = sort_pool(h0, bg) + h1 = sort_pool(bg, h0) assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2 def uniform_attention(g, shape): diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py index 58b5c8698c55..7c9ae57a4e7b 100644 --- a/tests/pytorch/test_nn.py +++ b/tests/pytorch/test_nn.py @@ -24,13 +24,13 @@ def test_graph_conv(): print(conv) # test#1: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias)) # test#2: more-dim h0 = F.ones((3, 5, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias)) @@ -40,12 +40,12 @@ def test_graph_conv(): conv = conv.to(ctx) # test#3: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 # test#4: basic h0 = F.ones((3, 5, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 @@ -54,12 +54,12 @@ def test_graph_conv(): conv = conv.to(ctx) # test#3: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 # test#4: basic h0 = F.ones((3, 5, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 @@ -94,7 +94,7 @@ def test_tagconv(): # test#1: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert len(g.ndata) == 0 assert len(g.edata) == 0 shp = norm.shape + (1,) * (h0.dim() - 1) @@ -107,7 +107,7 @@ def test_tagconv(): conv = conv.to(ctx) # test#2: basic h0 = F.ones((3, 5)) - h1 = conv(h0, g) + h1 = conv(g, h0) assert h1.shape[-1] == 2 # test reset_parameters @@ -127,7 +127,7 @@ def test_set2set(): # test#1: basic h0 = F.randn((g.number_of_nodes(), 5)) - h1 = s2s(h0, g) + h1 = s2s(g, h0) assert h1.shape[0] == 10 and h1.dim() == 1 # test#2: batched graph @@ -135,7 +135,7 @@ def test_set2set(): g2 = dgl.DGLGraph(nx.path_graph(5)) bg = dgl.batch([g, g1, g2]) h0 = F.randn((bg.number_of_nodes(), 5)) - h1 = s2s(h0, bg) + h1 = s2s(bg, h0) assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2 def test_glob_att_pool(): @@ -149,13 +149,13 @@ def test_glob_att_pool(): # test#1: basic h0 = F.randn((g.number_of_nodes(), 5)) - h1 = gap(h0, g) + h1 = gap(g, h0) assert h1.shape[0] == 10 and h1.dim() == 1 # test#2: batched graph bg = dgl.batch([g, g, g, g]) h0 = F.randn((bg.number_of_nodes(), 5)) - h1 = gap(h0, bg) + h1 = gap(bg, h0) assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2 def test_simple_pool(): @@ -176,13 +176,13 @@ def test_simple_pool(): max_pool = max_pool.to(ctx) sort_pool = sort_pool.to(ctx) h0 = h0.to(ctx) - h1 = sum_pool(h0, g) + h1 = sum_pool(g, h0) assert F.allclose(h1, F.sum(h0, 0)) - h1 = avg_pool(h0, g) + h1 = avg_pool(g, h0) assert F.allclose(h1, F.mean(h0, 0)) - h1 = max_pool(h0, g) + h1 = max_pool(g, h0) assert F.allclose(h1, F.max(h0, 0)) - h1 = sort_pool(h0, g) + h1 = sort_pool(g, h0) assert h1.shape[0] == 10 * 5 and h1.dim() == 1 # test#2: batched graph @@ -192,7 +192,7 @@ def test_simple_pool(): if F.gpu_ctx(): h0 = h0.to(ctx) - h1 = sum_pool(h0, bg) + h1 = sum_pool(bg, h0) truth = th.stack([F.sum(h0[:15], 0), F.sum(h0[15:20], 0), F.sum(h0[20:35], 0), @@ -200,7 +200,7 @@ def test_simple_pool(): F.sum(h0[40:55], 0)], 0) assert F.allclose(h1, truth) - h1 = avg_pool(h0, bg) + h1 = avg_pool(bg, h0) truth = th.stack([F.mean(h0[:15], 0), F.mean(h0[15:20], 0), F.mean(h0[20:35], 0), @@ -208,7 +208,7 @@ def test_simple_pool(): F.mean(h0[40:55], 0)], 0) assert F.allclose(h1, truth) - h1 = max_pool(h0, bg) + h1 = max_pool(bg, h0) truth = th.stack([F.max(h0[:15], 0), F.max(h0[15:20], 0), F.max(h0[20:35], 0), @@ -216,7 +216,7 @@ def test_simple_pool(): F.max(h0[40:55], 0)], 0) assert F.allclose(h1, truth) - h1 = sort_pool(h0, bg) + h1 = sort_pool(bg, h0) assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2 def test_set_trans(): @@ -234,11 +234,11 @@ def test_set_trans(): # test#1: basic h0 = F.randn((g.number_of_nodes(), 50)) - h1 = st_enc_0(h0, g) + h1 = st_enc_0(g, h0) assert h1.shape == h0.shape - h1 = st_enc_1(h0, g) + h1 = st_enc_1(g, h0) assert h1.shape == h0.shape - h2 = st_dec(h1, g) + h2 = st_dec(g, h1) assert h2.shape[0] == 200 and h2.dim() == 1 # test#2: batched graph @@ -246,12 +246,12 @@ def test_set_trans(): g2 = dgl.DGLGraph(nx.path_graph(10)) bg = dgl.batch([g, g1, g2]) h0 = F.randn((bg.number_of_nodes(), 50)) - h1 = st_enc_0(h0, bg) + h1 = st_enc_0(bg, h0) assert h1.shape == h0.shape - h1 = st_enc_1(h0, bg) + h1 = st_enc_1(bg, h0) assert h1.shape == h0.shape - h2 = st_dec(h1, bg) + h2 = st_dec(bg, h1) assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2 def uniform_attention(g, shape): @@ -375,7 +375,7 @@ def test_gat_conv(): gat = gat.to(ctx) feat = feat.to(ctx) - h = gat(feat, g) + h = gat(g, feat) assert h.shape[-1] == 2 and h.shape[-2] == 4 def test_sage_conv(): @@ -389,7 +389,7 @@ def test_sage_conv(): sage = sage.to(ctx) feat = feat.to(ctx) - h = sage(feat, g) + h = sage(g, feat) assert h.shape[-1] == 10 def test_sgc_conv(): @@ -403,7 +403,7 @@ def test_sgc_conv(): sgc = sgc.to(ctx) feat = feat.to(ctx) - h = sgc(feat, g) + h = sgc(g, feat) assert h.shape[-1] == 10 # cached @@ -412,8 +412,8 @@ def test_sgc_conv(): if F.gpu_ctx(): sgc = sgc.to(ctx) - h_0 = sgc(feat, g) - h_1 = sgc(feat + 1, g) + h_0 = sgc(g, feat) + h_1 = sgc(g, feat + 1) assert F.allclose(h_0, h_1) assert h_0.shape[-1] == 10 @@ -427,7 +427,7 @@ def test_appnp_conv(): appnp = appnp.to(ctx) feat = feat.to(ctx) - h = appnp(feat, g) + h = appnp(g, feat) assert h.shape[-1] == 5 def test_gin_conv(): @@ -444,7 +444,7 @@ def test_gin_conv(): gin = gin.to(ctx) feat = feat.to(ctx) - h = gin(feat, g) + h = gin(g, feat) assert h.shape[-1] == 12 def test_agnn_conv(): @@ -457,7 +457,7 @@ def test_agnn_conv(): agnn = agnn.to(ctx) feat = feat.to(ctx) - h = agnn(feat, g) + h = agnn(g, feat) assert h.shape[-1] == 5 def test_gated_graph_conv(): @@ -472,7 +472,7 @@ def test_gated_graph_conv(): feat = feat.to(ctx) etypes = etypes.to(ctx) - h = ggconv(feat, etypes, g) + h = ggconv(g, feat, etypes) # current we only do shape check assert h.shape[-1] == 10 @@ -489,7 +489,7 @@ def test_nn_conv(): feat = feat.to(ctx) efeat = efeat.to(ctx) - h = nnconv(feat, efeat, g) + h = nnconv(g, feat, efeat) # currently we only do shape check assert h.shape[-1] == 10 @@ -505,7 +505,7 @@ def test_gmm_conv(): feat = feat.to(ctx) pseudo = pseudo.to(ctx) - h = gmmconv(feat, pseudo, g) + h = gmmconv(g, feat, pseudo) # currently we only do shape check assert h.shape[-1] == 10 @@ -523,8 +523,8 @@ def test_dense_graph_conv(): dense_conv = dense_conv.to(ctx) feat = feat.to(ctx) - out_conv = conv(feat, g) - out_dense_conv = dense_conv(feat, adj) + out_conv = conv(g, feat) + out_dense_conv = dense_conv(adj, feat) assert F.allclose(out_conv, out_dense_conv) def test_dense_sage_conv(): @@ -541,8 +541,8 @@ def test_dense_sage_conv(): dense_sage = dense_sage.to(ctx) feat = feat.to(ctx) - out_sage = sage(feat, g) - out_dense_sage = dense_sage(feat, adj) + out_sage = sage(g, feat) + out_dense_sage = dense_sage(adj, feat) assert F.allclose(out_sage, out_dense_sage) def test_dense_cheb_conv(): @@ -562,8 +562,8 @@ def test_dense_cheb_conv(): dense_cheb = dense_cheb.to(ctx) feat = feat.to(ctx) - out_cheb = cheb(feat, g) - out_dense_cheb = dense_cheb(feat, adj) + out_cheb = cheb(g, feat, [2.0]) + out_dense_cheb = dense_cheb(adj, feat, 2.0) assert F.allclose(out_cheb, out_dense_cheb) if __name__ == '__main__':