diff --git a/docs/source/api/python/function.rst b/docs/source/api/python/function.rst index 31e81f3d8467..869eb50db95e 100644 --- a/docs/source/api/python/function.rst +++ b/docs/source/api/python/function.rst @@ -14,6 +14,32 @@ Message functions copy_src copy_edge src_mul_edge + copy_u + copy_e + u_add_v + u_sub_v + u_mul_v + u_div_v + u_add_e + u_sub_e + u_mul_e + u_div_e + v_add_u + v_sub_u + v_mul_u + v_div_u + v_add_e + v_sub_e + v_mul_e + v_div_e + e_add_u + e_sub_u + e_mul_u + e_div_u + e_add_v + e_sub_v + e_mul_v + e_div_v Reduce functions ---------------- @@ -23,3 +49,6 @@ Reduce functions sum max + min + prod + mean diff --git a/docs/source/api/python/nn.pytorch.rst b/docs/source/api/python/nn.pytorch.rst index cf7535df8ca8..acb1018e3603 100644 --- a/docs/source/api/python/nn.pytorch.rst +++ b/docs/source/api/python/nn.pytorch.rst @@ -16,6 +16,62 @@ dgl.nn.pytorch.conv :members: forward :show-inheritance: +.. autoclass:: dgl.nn.pytorch.conv.TAGConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.GATConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.SAGEConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.SGConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.APPNPConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.GINConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.GatedGraphConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.GMMConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.ChebConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.AGNNConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.NNConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.DenseGraphConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.DenseSAGEConv + :members: forward + :show-inheritance: + +.. autoclass:: dgl.nn.pytorch.conv.DenseChebConv + :members: forward + :show-inheritance: + dgl.nn.pytorch.glob ------------------- .. automodule:: dgl.nn.pytorch.glob diff --git a/docs/source/api/python/transform.rst b/docs/source/api/python/transform.rst index ebed453857ce..8b0528e4a8d9 100644 --- a/docs/source/api/python/transform.rst +++ b/docs/source/api/python/transform.rst @@ -12,3 +12,6 @@ Transform -- Graph Transformation reverse to_simple_graph to_bidirected + khop_adj + khop_graph + laplacian_lambda_max diff --git a/examples/pytorch/appnp/appnp.py b/examples/pytorch/appnp/appnp.py index d9ca9f1043f0..5a055789255a 100644 --- a/examples/pytorch/appnp/appnp.py +++ b/examples/pytorch/appnp/appnp.py @@ -8,44 +8,7 @@ import torch import torch.nn as nn import dgl.function as fn - - -class GraphPropagation(nn.Module): - def __init__(self, - g, - edge_drop, - alpha, - k): - super(GraphPropagation, self).__init__() - self.g = g - self.alpha = alpha - self.k = k - if edge_drop: - self.edge_drop = nn.Dropout(edge_drop) - else: - self.edge_drop = 0. 
- - def forward(self, h): - self.cached_h = h - for _ in range(self.k): - # normalization by square root of src degree - h = h * self.g.ndata['norm'] - self.g.ndata['h'] = h - if self.edge_drop: - # performing edge dropout - ed = self.edge_drop(torch.ones((self.g.number_of_edges(), 1), device=h.device)) - self.g.edata['d'] = ed - self.g.update_all(fn.src_mul_edge(src='h', edge='d', out='m'), - fn.sum(msg='m', out='h')) - else: - self.g.update_all(fn.copy_src(src='h', out='m'), - fn.sum(msg='m', out='h')) - h = self.g.ndata.pop('h') - # normalization by square root of dst degree - h = h * self.g.ndata['norm'] - # update h using teleport probability alpha - h = h * (1 - self.alpha) + self.cached_h * self.alpha - return h +from dgl.nn.pytorch.conv import APPNPConv class APPNP(nn.Module): @@ -60,6 +23,7 @@ def __init__(self, alpha, k): super(APPNP, self).__init__() + self.g = g self.layers = nn.ModuleList() # input layer self.layers.append(nn.Linear(in_feats, hiddens[0])) @@ -73,7 +37,7 @@ def __init__(self, self.feat_drop = nn.Dropout(feat_drop) else: self.feat_drop = lambda x: x - self.propagate = GraphPropagation(g, edge_drop, alpha, k) + self.propagate = APPNPConv(k, alpha, edge_drop) self.reset_parameters() def reset_parameters(self): @@ -89,5 +53,5 @@ def forward(self, features): h = self.activation(layer(h)) h = self.layers[-1](self.feat_drop(h)) # propagation step - h = self.propagate(h) + h = self.propagate(h, self.g) return h diff --git a/examples/pytorch/gat/gat.py b/examples/pytorch/gat/gat.py index b1ed34735606..2103961bef24 100644 --- a/examples/pytorch/gat/gat.py +++ b/examples/pytorch/gat/gat.py @@ -10,78 +10,8 @@ import torch import torch.nn as nn import dgl.function as fn -from dgl.nn.pytorch import edge_softmax +from dgl.nn.pytorch import edge_softmax, GATConv -class GraphAttention(nn.Module): - def __init__(self, - g, - in_dim, - out_dim, - num_heads, - feat_drop, - attn_drop, - alpha, - residual=False): - super(GraphAttention, self).__init__() - self.g = g - self.num_heads = num_heads - self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False) - if feat_drop: - self.feat_drop = nn.Dropout(feat_drop) - else: - self.feat_drop = lambda x : x - if attn_drop: - self.attn_drop = nn.Dropout(attn_drop) - else: - self.attn_drop = lambda x : x - self.attn_l = nn.Parameter(torch.Tensor(size=(1, num_heads, out_dim))) - self.attn_r = nn.Parameter(torch.Tensor(size=(1, num_heads, out_dim))) - nn.init.xavier_normal_(self.fc.weight.data, gain=1.414) - nn.init.xavier_normal_(self.attn_l.data, gain=1.414) - nn.init.xavier_normal_(self.attn_r.data, gain=1.414) - self.leaky_relu = nn.LeakyReLU(alpha) - self.softmax = edge_softmax - self.residual = residual - if residual: - if in_dim != out_dim: - self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False) - nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414) - else: - self.res_fc = None - - def forward(self, inputs): - # prepare - h = self.feat_drop(inputs) # NxD - ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD' - a1 = (ft * self.attn_l).sum(dim=-1).unsqueeze(-1) # N x H x 1 - a2 = (ft * self.attn_r).sum(dim=-1).unsqueeze(-1) # N x H x 1 - self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2}) - # 1. compute edge attention - self.g.apply_edges(self.edge_attention) - # 2. compute softmax - self.edge_softmax() - # 3. compute the aggregated node features scaled by the dropped, - # unnormalized attention values. 
- self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft')) - ret = self.g.ndata['ft'] - # 4. residual - if self.residual: - if self.res_fc is not None: - resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD' - else: - resval = torch.unsqueeze(h, 1) # Nx1xD' - ret = resval + ret - return ret - - def edge_attention(self, edges): - # an edge UDF to compute unnormalized attention values from src and dst - a = self.leaky_relu(edges.src['a1'] + edges.dst['a2']) - return {'a' : a} - - def edge_softmax(self): - attention = self.softmax(self.g, self.g.edata.pop('a')) - # Dropout attention scores and save them - self.g.edata['a_drop'] = self.attn_drop(attention) class GAT(nn.Module): def __init__(self, @@ -94,7 +24,7 @@ def __init__(self, activation, feat_drop, attn_drop, - alpha, + negative_slope, residual): super(GAT, self).__init__() self.g = g @@ -102,24 +32,24 @@ def __init__(self, self.gat_layers = nn.ModuleList() self.activation = activation # input projection (no residual) - self.gat_layers.append(GraphAttention( - g, in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False)) + self.gat_layers.append(GATConv( + in_dim, num_hidden, heads[0], + feat_drop, attn_drop, negative_slope, False, self.activation)) # hidden layers for l in range(1, num_layers): # due to multi-head, the in_dim = num_hidden * num_heads - self.gat_layers.append(GraphAttention( - g, num_hidden * heads[l-1], num_hidden, heads[l], - feat_drop, attn_drop, alpha, residual)) + self.gat_layers.append(GATConv( + num_hidden * heads[l-1], num_hidden, heads[l], + feat_drop, attn_drop, negative_slope, residual, self.activation)) # output projection - self.gat_layers.append(GraphAttention( - g, num_hidden * heads[-2], num_classes, heads[-1], - feat_drop, attn_drop, alpha, residual)) + self.gat_layers.append(GATConv( + num_hidden * heads[-2], num_classes, heads[-1], + feat_drop, attn_drop, negative_slope, residual, None)) def forward(self, inputs): h = inputs for l in range(self.num_layers): - h = self.gat_layers[l](h).flatten(1) - h = self.activation(h) + h = self.gat_layers[l](h, self.g).flatten(1) # output projection - logits = self.gat_layers[-1](h).mean(1) + logits = self.gat_layers[-1](h, self.g).mean(1) return logits diff --git a/examples/pytorch/gat/train.py b/examples/pytorch/gat/train.py index 463142f050ab..e77d8fc3bfcc 100644 --- a/examples/pytorch/gat/train.py +++ b/examples/pytorch/gat/train.py @@ -86,7 +86,7 @@ def main(args): F.elu, args.in_drop, args.attn_drop, - args.alpha, + args.negative_slope, args.residual) print(model) stopper = EarlyStopping(patience=100) @@ -161,8 +161,8 @@ def main(args): help="learning rate") parser.add_argument('--weight-decay', type=float, default=5e-4, help="weight decay") - parser.add_argument('--alpha', type=float, default=0.2, - help="the negative slop of leaky relu") + parser.add_argument('--negative-slope', type=float, default=0.2, + help="the negative slope of leaky relu") parser.add_argument('--fastmode', action="store_true", default=False, help="skip re-evaluate the validation set") args = parser.parse_args() diff --git a/examples/pytorch/gcn/train.py b/examples/pytorch/gcn/train.py index bbfeb57accfb..be90b20c5ef4 100644 --- a/examples/pytorch/gcn/train.py +++ b/examples/pytorch/gcn/train.py @@ -57,8 +57,8 @@ def main(args): g = data.graph # add self loop if args.self_loop: - g.remove_edges_from(g.selfloop_edges()) - g.add_edges_from(zip(g.nodes(), g.nodes())) + g.remove_edges_from(g.selfloop_edges()) + 
g.add_edges_from(zip(g.nodes(), g.nodes())) g = DGLGraph(g) n_edges = g.number_of_edges() # normalization diff --git a/examples/pytorch/gin/gin.py b/examples/pytorch/gin/gin.py index d9a41cc974dd..1260020675de 100644 --- a/examples/pytorch/gin/gin.py +++ b/examples/pytorch/gin/gin.py @@ -9,72 +9,24 @@ import torch import torch.nn as nn import torch.nn.functional as F +from dgl.nn.pytorch.conv import GINConv import dgl import dgl.function as fn -# Sends a message of node feature h. -msg = fn.copy_src(src='h', out='m') -reduce_sum = fn.sum(msg='m', out='h') -reduce_max = fn.max(msg='m', out='h') - - -def reduce_mean(nodes): - return {'h': torch.mean(nodes.mailbox['m'], dim=1)[0]} - - -class ApplyNodes(nn.Module): +class ApplyNodeFunc(nn.Module): """Update the node feature hv with MLP, BN and ReLU.""" - def __init__(self, mlp, layer): - super(ApplyNodes, self).__init__() + def __init__(self, mlp): + super(ApplyNodeFunc, self).__init__() self.mlp = mlp self.bn = nn.BatchNorm1d(self.mlp.output_dim) - self.layer = layer - def forward(self, nodes): - h = self.mlp(nodes.data['h']) + def forward(self, h): + h = self.mlp(h) h = self.bn(h) h = F.relu(h) - - return {'h': h} - - -class GINLayer(nn.Module): - """Neighbor pooling and reweight nodes before send graph into MLP""" - def __init__(self, eps, layer, mlp, neighbor_pooling_type, learn_eps): - super(GINLayer, self).__init__() - self.bn = nn.BatchNorm1d(mlp.output_dim) - self.neighbor_pooling_type = neighbor_pooling_type - self.eps = eps - self.learn_eps = learn_eps - self.layer = layer - self.apply_mod = ApplyNodes(mlp, layer) - - def forward(self, g, feature): - g.ndata['h'] = feature - - if self.neighbor_pooling_type == 'sum': - reduce_func = reduce_sum - elif self.neighbor_pooling_type == 'mean': - reduce_func = reduce_mean - elif self.neighbor_pooling_type == 'max': - reduce_func = reduce_max - else: - raise NotImplementedError() - - h = feature # h0 - g.update_all(msg, reduce_func) - pooled = g.ndata['h'] - - # reweight the center node when aggregating it with its neighbors - if self.learn_eps: - pooled = pooled + (1 + self.eps[self.layer])*h - - g.ndata['h'] = pooled - g.apply_nodes(func=self.apply_mod) - - return g.ndata.pop('h') + return h class MLP(nn.Module): @@ -168,7 +120,6 @@ def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, self.num_layers = num_layers self.graph_pooling_type = graph_pooling_type self.learn_eps = learn_eps - self.eps = nn.Parameter(torch.zeros(self.num_layers - 1)) # List of MLPs self.ginlayers = torch.nn.ModuleList() @@ -180,8 +131,8 @@ def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, else: mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim) - self.ginlayers.append(GINLayer( - self.eps, layer, mlp, neighbor_pooling_type, self.learn_eps)) + self.ginlayers.append( + GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps)) self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) # Linear function for graph poolings of output of each layer @@ -204,7 +155,7 @@ def forward(self, g): hidden_rep = [h] for layer in range(self.num_layers - 1): - h = self.ginlayers[layer](g, h) + h = self.ginlayers[layer](h, g) hidden_rep.append(h) score_over_layer = 0 diff --git a/examples/pytorch/gin/main.py b/examples/pytorch/gin/main.py index af87ae4f9727..e2d81967efa9 100644 --- a/examples/pytorch/gin/main.py +++ b/examples/pytorch/gin/main.py @@ -148,7 +148,7 @@ def main(args): lrbar.set_description( "the learning eps with learn_eps={} is: {}".format( - args.learn_eps, 
model.eps.data)) + args.learn_eps, [layer.eps.data for layer in model.ginlayers])) tbar.close() vbar.close() diff --git a/examples/pytorch/graphsage/README.md b/examples/pytorch/graphsage/README.md index eae32929fb73..984415132de3 100644 --- a/examples/pytorch/graphsage/README.md +++ b/examples/pytorch/graphsage/README.md @@ -22,6 +22,6 @@ Run with following (available dataset: "cora", "citeseer", "pubmed") python3 graphsage.py --dataset cora --gpu 0 ``` -* cora: ~0.8470 -* citeseer: ~0.6870 -* pubmed: ~0.7730 +* cora: ~0.8330 +* citeseer: ~0.7110 +* pubmed: ~0.7830 diff --git a/examples/pytorch/graphsage/graphsage.py b/examples/pytorch/graphsage/graphsage.py index c65fac19d7fe..dbaabeb35dba 100644 --- a/examples/pytorch/graphsage/graphsage.py +++ b/examples/pytorch/graphsage/graphsage.py @@ -13,100 +13,7 @@ import torch.nn.functional as F from dgl import DGLGraph from dgl.data import register_data_args, load_data -import dgl.function as fn - - -class Aggregator(nn.Module): - def __init__(self, g, in_feats, out_feats, activation=None, bias=True): - super(Aggregator, self).__init__() - self.g = g - self.linear = nn.Linear(in_feats, out_feats, bias=bias) # (F, EF) or (2F, EF) - self.activation = activation - nn.init.xavier_uniform_(self.linear.weight, gain=nn.init.calculate_gain('relu')) - - def forward(self, node): - nei = node.mailbox['m'] # (B, N, F) - h = node.data['h'] # (B, F) - h = self.concat(h, nei, node) # (B, F) or (B, 2F) - h = self.linear(h) # (B, EF) - if self.activation: - h = self.activation(h) - norm = torch.pow(h, 2) - norm = torch.sum(norm, 1, keepdim=True) - norm = torch.pow(norm, -0.5) - norm[torch.isinf(norm)] = 0 - # h = h * norm - return {'h': h} - - @abc.abstractmethod - def concat(self, h, nei, nodes): - raise NotImplementedError - - -class MeanAggregator(Aggregator): - def __init__(self, g, in_feats, out_feats, activation, bias): - super(MeanAggregator, self).__init__(g, in_feats, out_feats, activation, bias) - - def concat(self, h, nei, nodes): - degs = self.g.in_degrees(nodes.nodes()).float() - if h.is_cuda: - degs = degs.cuda(h.device) - concatenate = torch.cat((nei, h.unsqueeze(1)), 1) - concatenate = torch.sum(concatenate, 1) / degs.unsqueeze(1) - return concatenate # (B, F) - - -class PoolingAggregator(Aggregator): - def __init__(self, g, in_feats, out_feats, activation, bias): # (2F, F) - super(PoolingAggregator, self).__init__(g, in_feats*2, out_feats, activation, bias) - self.mlp = PoolingAggregator.MLP(in_feats, in_feats, F.relu, False, True) - - def concat(self, h, nei, nodes): - nei = self.mlp(nei) # (B, F) - concatenate = torch.cat((nei, h), 1) # (B, 2F) - return concatenate - - class MLP(nn.Module): - def __init__(self, in_feats, out_feats, activation, dropout, bias): # (F, F) - super(PoolingAggregator.MLP, self).__init__() - self.linear = nn.Linear(in_feats, out_feats, bias=bias) # (F, F) - self.dropout = nn.Dropout(p=dropout) - self.activation = activation - nn.init.xavier_uniform_(self.linear.weight, gain=nn.init.calculate_gain('relu')) - - def forward(self, nei): - nei = self.dropout(nei) # (B, N, F) - nei = self.linear(nei) - if self.activation: - nei = self.activation(nei) - max_value = torch.max(nei, dim=1)[0] # (B, F) - return max_value - - -class GraphSAGELayer(nn.Module): - def __init__(self, - g, - in_feats, - out_feats, - activation, - dropout, - aggregator_type, - bias=True, - ): - super(GraphSAGELayer, self).__init__() - self.g = g - self.dropout = nn.Dropout(p=dropout) - if aggregator_type == "pooling": - self.aggregator = 
PoolingAggregator(g, in_feats, out_feats, activation, bias) - else: - self.aggregator = MeanAggregator(g, in_feats, out_feats, activation, bias) - - def forward(self, h): - h = self.dropout(h) - self.g.ndata['h'] = h - self.g.update_all(fn.copy_src(src='h', out='m'), self.aggregator) - h = self.g.ndata.pop('h') - return h +from dgl.nn.pytorch.conv import SAGEConv class GraphSAGE(nn.Module): @@ -121,19 +28,20 @@ def __init__(self, aggregator_type): super(GraphSAGE, self).__init__() self.layers = nn.ModuleList() + self.g = g # input layer - self.layers.append(GraphSAGELayer(g, in_feats, n_hidden, activation, dropout, aggregator_type)) + self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) # hidden layers for i in range(n_layers - 1): - self.layers.append(GraphSAGELayer(g, n_hidden, n_hidden, activation, dropout, aggregator_type)) + self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) # output layer - self.layers.append(GraphSAGELayer(g, n_hidden, n_classes, None, dropout, aggregator_type)) + self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, feat_drop=dropout, activation=None)) # activation None def forward(self, features): h = features for layer in self.layers: - h = layer(h) + h = layer(h, self.g) return h @@ -182,7 +90,9 @@ def main(args): print("use cuda:", args.gpu) # graph preprocess and calculate normalization factor - g = DGLGraph(data.graph) + g = data.graph + g.remove_edges_from(g.selfloop_edges()) + g = DGLGraph(g) n_edges = g.number_of_edges() # create GraphSAGE model @@ -231,7 +141,7 @@ def main(args): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='GCN') + parser = argparse.ArgumentParser(description='GraphSAGE') register_data_args(parser) parser.add_argument("--dropout", type=float, default=0.5, help="dropout probability") @@ -247,8 +157,8 @@ def main(args): help="number of hidden gcn layers") parser.add_argument("--weight-decay", type=float, default=5e-4, help="Weight for L2 loss") - parser.add_argument("--aggregator-type", type=str, default="mean", - help="Weight for L2 loss") + parser.add_argument("--aggregator-type", type=str, default="gcn", + help="Aggregator type: mean/gcn/pool/lstm") args = parser.parse_args() print(args) diff --git a/examples/pytorch/model_zoo/citation_network/README.md b/examples/pytorch/model_zoo/citation_network/README.md new file mode 100644 index 000000000000..48ff3cc2ec1e --- /dev/null +++ b/examples/pytorch/model_zoo/citation_network/README.md @@ -0,0 +1,30 @@ +# Node Classification on Citation Networks + +This example shows how to use modules defined in `dgl.nn.pytorch.conv` to do node classification on +citation network datasets. 
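+
+All model definitions follow the same pattern: stack modules from `dgl.nn.pytorch.conv`
+and call each layer with the node features and the graph. A trimmed, illustrative sketch
+of the GCN variant (the full definitions live in `models.py`; the hidden size here is arbitrary):
+
+```
+import torch.nn as nn
+import torch.nn.functional as F
+from dgl.nn.pytorch import GraphConv
+
+class GCN(nn.Module):
+    def __init__(self, g, in_feats, n_classes, n_hidden=16):
+        super(GCN, self).__init__()
+        self.g = g
+        self.layers = nn.ModuleList([
+            GraphConv(in_feats, n_hidden, activation=F.relu),
+            GraphConv(n_hidden, n_classes),
+        ])
+
+    def forward(self, features):
+        h = features
+        for layer in self.layers:
+            # each conv layer takes the node features and the graph
+            h = layer(h, self.g)
+        return h
+```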
+ +## Datasets + +- Cora +- Citeseer +- Pubmed + +## Models + +- GCN: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/pdf/1609.02907) +- GAT: [Graph Attention Networks](https://arxiv.org/abs/1710.10903) +- GraphSAGE [Inductive Representation Learning on Large Graphs](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf) +- APPNP: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/pdf/1810.05997) +- GIN: [How Powerful are Graph Neural Networks?](https://arxiv.org/abs/1810.00826) +- TAGCN: [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/abs/1710.10370) +- SGC: [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153) +- AGNN: [Attention-based Graph Neural Network for Semi-supervised Learning](https://arxiv.org/pdf/1803.03735.pdf) +- ChebNet: [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375) + +## Usage + +``` +python run.py [--gpu] --model MODEL_NAME --dataset DATASET_NAME [--self-loop] +``` + +The hyperparameters might not be the optimal, you could specify them manually in `conf.py`. diff --git a/examples/pytorch/model_zoo/citation_network/conf.py b/examples/pytorch/model_zoo/citation_network/conf.py new file mode 100644 index 000000000000..57297b86d419 --- /dev/null +++ b/examples/pytorch/model_zoo/citation_network/conf.py @@ -0,0 +1,56 @@ +import torch as th +import torch.nn.functional as F + +GCN_CONFIG = { + 'extra_args': [16, 1, F.relu, 0.5], + 'lr': 1e-2, + 'weight_decay': 5e-4, +} + +GAT_CONFIG = { + 'extra_args': [8, 1, [8] * 1 + [1], F.elu, 0.6, 0.6, 0.2, False], + 'lr': 0.005, + 'weight_decay': 5e-4, +} + +GRAPHSAGE_CONFIG = { + 'extra_args': [16, 1, F.relu, 0.5, 'gcn'], + 'lr': 1e-2, + 'weight_decay': 5e-4, +} + +APPNP_CONFIG = { + 'extra_args': [64, 1, F.relu, 0.5, 0.5, 0.1, 10], + 'lr': 1e-2, + 'weight_decay': 5e-4, +} + +TAGCN_CONFIG = { + 'extra_args': [16, 1, F.relu, 0.5], + 'lr': 1e-2, + 'weight_decay': 5e-4, +} + +AGNN_CONFIG = { + 'extra_args': [32, 2, 1.0, True, 0.5], + 'lr': 1e-2, + 'weight_decay': 5e-4, +} + +SGC_CONFIG = { + 'extra_args': [None, 2, False], + 'lr': 0.2, + 'weight_decay': 5e-6, +} + +GIN_CONFIG = { + 'extra_args': [16, 1, 0, True], + 'lr': 1e-2, + 'weight_decay': 5e-6, +} + +CHEBNET_CONFIG = { + 'extra_args': [16, 1, 3, True], + 'lr': 1e-2, + 'weight_decay': 5e-4, +} diff --git a/examples/pytorch/model_zoo/citation_network/models.py b/examples/pytorch/model_zoo/citation_network/models.py new file mode 100644 index 000000000000..1ec7de10a09e --- /dev/null +++ b/examples/pytorch/model_zoo/citation_network/models.py @@ -0,0 +1,320 @@ +import torch +import torch.nn as nn +from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv, GINConv,\ + APPNPConv, TAGConv, SGConv, AGNNConv, ChebConv + + +class GCN(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + n_layers, + activation, + dropout): + super(GCN, self).__init__() + self.g = g + self.layers = nn.ModuleList() + # input layer + self.layers.append(GraphConv(in_feats, n_hidden, activation=activation)) + # hidden layers + for i in range(n_layers - 1): + self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation)) + # output layer + self.layers.append(GraphConv(n_hidden, n_classes)) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, features): + h = features + for i, layer in enumerate(self.layers): + if i != 0: + h = self.dropout(h) + h = layer(h, self.g) + return h + + +class 
GAT(nn.Module): + def __init__(self, + g, + in_dim, + num_classes, + num_hidden, + num_layers, + heads, + activation, + feat_drop, + attn_drop, + negative_slope, + residual): + super(GAT, self).__init__() + self.g = g + self.num_layers = num_layers + self.gat_layers = nn.ModuleList() + self.activation = activation + # input projection (no residual) + self.gat_layers.append(GATConv( + in_dim, num_hidden, heads[0], + feat_drop, attn_drop, negative_slope, False, self.activation)) + # hidden layers + for l in range(1, num_layers): + # due to multi-head, the in_dim = num_hidden * num_heads + self.gat_layers.append(GATConv( + num_hidden * heads[l-1], num_hidden, heads[l], + feat_drop, attn_drop, negative_slope, residual, self.activation)) + # output projection + self.gat_layers.append(GATConv( + num_hidden * heads[-2], num_classes, heads[-1], + feat_drop, attn_drop, negative_slope, residual, None)) + + def forward(self, inputs): + h = inputs + for l in range(self.num_layers): + h = self.gat_layers[l](h, self.g).flatten(1) + # output projection + logits = self.gat_layers[-1](h, self.g).mean(1) + return logits + + +class GraphSAGE(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + n_layers, + activation, + dropout, + aggregator_type): + super(GraphSAGE, self).__init__() + self.layers = nn.ModuleList() + self.g = g + + # input layer + self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) + # hidden layers + for i in range(n_layers - 1): + self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) + # output layer + self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, feat_drop=dropout, activation=None)) # activation None + + def forward(self, features): + h = features + for layer in self.layers: + h = layer(h, self.g) + return h + + +class APPNP(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + n_layers, + activation, + feat_drop, + edge_drop, + alpha, + k): + super(APPNP, self).__init__() + self.g = g + self.layers = nn.ModuleList() + # input layer + self.layers.append(nn.Linear(in_feats, n_hidden)) + # hidden layers + for i in range(1, n_layers): + self.layers.append(nn.Linear(n_hidden, n_hidden)) + # output layer + self.layers.append(nn.Linear(n_hidden, n_classes)) + self.activation = activation + if feat_drop: + self.feat_drop = nn.Dropout(feat_drop) + else: + self.feat_drop = lambda x: x + self.propagate = APPNPConv(k, alpha, edge_drop) + self.reset_parameters() + + def reset_parameters(self): + for layer in self.layers: + layer.reset_parameters() + + def forward(self, features): + # prediction step + h = features + h = self.feat_drop(h) + h = self.activation(self.layers[0](h)) + for layer in self.layers[1:-1]: + h = self.activation(layer(h)) + h = self.layers[-1](self.feat_drop(h)) + # propagation step + h = self.propagate(h, self.g) + return h + + +class TAGCN(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + n_layers, + activation, + dropout): + super(TAGCN, self).__init__() + self.g = g + self.layers = nn.ModuleList() + # input layer + self.layers.append(TAGConv(in_feats, n_hidden, activation=activation)) + # hidden layers + for i in range(n_layers - 1): + self.layers.append(TAGConv(n_hidden, n_hidden, activation=activation)) + # output layer + self.layers.append(TAGConv(n_hidden, n_classes)) #activation=None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, features): + h = features + 
for i, layer in enumerate(self.layers): + if i != 0: + h = self.dropout(h) + h = layer(h, self.g) + return h + + +class AGNN(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + n_layers, + init_beta, + learn_beta, + dropout): + super(AGNN, self).__init__() + self.g = g + self.layers = nn.ModuleList( + [AGNNConv(init_beta, learn_beta) for _ in range(n_layers)] + ) + self.proj = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(in_feats, n_hidden), + nn.ReLU() + ) + self.cls = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(n_hidden, n_classes) + ) + + def forward(self, features): + h = self.proj(features) + for layer in self.layers: + h = layer(h, self.g) + return self.cls(h) + + +class SGC(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + k, + bias): + super(SGC, self).__init__() + self.g = g + self.net = SGConv(in_feats, + n_classes, + k=k, + cached=True, + bias=bias) + + def forward(self, features): + return self.net(features, self.g) + + +class GIN(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + n_layers, + init_eps, + learn_eps): + super(GIN, self).__init__() + self.g = g + self.layers = nn.ModuleList() + self.layers.append( + GINConv( + nn.Sequential( + nn.Dropout(0.6), + nn.Linear(in_feats, n_hidden), + nn.ReLU(), + ), + 'mean', + init_eps, + learn_eps + ) + ) + for i in range(n_layers - 1): + self.layers.append( + GINConv( + nn.Sequential( + nn.Dropout(0.6), + nn.Linear(n_hidden, n_hidden), + nn.ReLU() + ), + 'mean', + init_eps, + learn_eps + ) + ) + self.layers.append( + GINConv( + nn.Sequential( + nn.Dropout(0.6), + nn.Linear(n_hidden, n_classes), + ), + 'mean', + init_eps, + learn_eps + ) + ) + + def forward(self, features): + h = features + for layer in self.layers: + h = layer(h, self.g) + return h + +class ChebNet(nn.Module): + def __init__(self, + g, + in_feats, + n_classes, + n_hidden, + n_layers, + k, + bias): + super(ChebNet, self).__init__() + self.g = g + self.layers = nn.ModuleList() + self.layers.append( + ChebConv(in_feats, n_hidden, k, bias) + ) + for _ in range(n_layers - 1): + self.layers.append( + ChebConv(n_hidden, n_hidden, k, bias) + ) + + self.layers.append( + ChebConv(n_hidden, n_classes, k, bias) + ) + + def forward(self, features): + h = features + for layer in self.layers: + h = layer(h, self.g) + return h \ No newline at end of file diff --git a/examples/pytorch/model_zoo/citation_network/run.py b/examples/pytorch/model_zoo/citation_network/run.py new file mode 100644 index 000000000000..cc2648204c2d --- /dev/null +++ b/examples/pytorch/model_zoo/citation_network/run.py @@ -0,0 +1,150 @@ +import argparse, time +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from models import * +from conf import * + + +def get_model_and_config(name): + name = name.lower() + if name == 'gcn': + return GCN, GCN_CONFIG + elif name == 'gat': + return GAT, GAT_CONFIG + elif name == 'graphsage': + return GraphSAGE, GRAPHSAGE_CONFIG + elif name == 'appnp': + return APPNP, APPNP_CONFIG + elif name == 'tagcn': + return TAGCN, TAGCN_CONFIG + elif name == 'agnn': + return AGNN, AGNN_CONFIG + elif name == 'sgc': + return SGC, SGC_CONFIG + elif name == 'gin': + return GIN, GIN_CONFIG + elif name == 'chebnet': + return ChebNet, CHEBNET_CONFIG + +def evaluate(model, features, labels, mask): + model.eval() + with torch.no_grad(): + logits = model(features) + logits = logits[mask] + 
labels = labels[mask] + _, indices = torch.max(logits, dim=1) + correct = torch.sum(indices == labels) + return correct.item() * 1.0 / len(labels) + +def main(args): + # load and preprocess dataset + data = load_data(args) + features = torch.FloatTensor(data.features) + labels = torch.LongTensor(data.labels) + train_mask = torch.ByteTensor(data.train_mask) + val_mask = torch.ByteTensor(data.val_mask) + test_mask = torch.ByteTensor(data.test_mask) + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().item(), + val_mask.sum().item(), + test_mask.sum().item())) + + if args.gpu < 0: + cuda = False + else: + cuda = True + torch.cuda.set_device(args.gpu) + features = features.cuda() + labels = labels.cuda() + train_mask = train_mask.cuda() + val_mask = val_mask.cuda() + test_mask = test_mask.cuda() + + # graph preprocess and calculate normalization factor + g = data.graph + # add self loop + if args.self_loop: + g.remove_edges_from(g.selfloop_edges()) + g.add_edges_from(zip(g.nodes(), g.nodes())) + g = DGLGraph(g) + n_edges = g.number_of_edges() + # normalization + degs = g.in_degrees().float() + norm = torch.pow(degs, -0.5) + norm[torch.isinf(norm)] = 0 + if cuda: + norm = norm.cuda() + g.ndata['norm'] = norm.unsqueeze(1) + + # create GCN model + GNN, config = get_model_and_config(args.model) + model = GNN(g, + in_feats, + n_classes, + *config['extra_args']) + + if cuda: + model.cuda() + + print(model) + + loss_fcn = torch.nn.CrossEntropyLoss() + + # use optimizer + optimizer = torch.optim.Adam(model.parameters(), + lr=config['lr'], + weight_decay=config['weight_decay']) + + # initialize graph + dur = [] + for epoch in range(200): + model.train() + if epoch >= 3: + t0 = time.time() + # forward + logits = model(features) + loss = loss_fcn(logits[train_mask], labels[train_mask]) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if epoch >= 3: + dur.append(time.time() - t0) + + acc = evaluate(model, features, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". 
format(epoch, np.mean(dur), loss.item(), + acc, n_edges / np.mean(dur) / 1000)) + + print() + acc = evaluate(model, features, labels, test_mask) + print("Test Accuracy {:.4f}".format(acc)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Node classification on citation networks.') + register_data_args(parser) + parser.add_argument("--model", type=str, default='gcn', + help='model to use, available models are gcn, gat, graphsage, gin,' + 'appnp, tagcn, sgc, agnn') + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--self-loop", action='store_true', + help="graph self-loop (default=False)") + args = parser.parse_args() + print(args) + main(args) \ No newline at end of file diff --git a/examples/pytorch/sgc/sgc.py b/examples/pytorch/sgc/sgc.py index c75e4d13d3d1..c793e32bb9d5 100644 --- a/examples/pytorch/sgc/sgc.py +++ b/examples/pytorch/sgc/sgc.py @@ -13,40 +13,13 @@ import dgl.function as fn from dgl import DGLGraph from dgl.data import register_data_args, load_data +from dgl.nn.pytorch.conv import SGConv -class SGCLayer(nn.Module): - def __init__(self, - g, - h, - in_feats, - out_feats, - bias=False, - K=2): - super(SGCLayer, self).__init__() - self.g = g - self.weight = nn.Linear(in_feats, out_feats, bias=bias) - self.K = K - # precomputing message passing - for _ in range(self.K): - # normalization by square root of src degree - h = h * self.g.ndata['norm'] - self.g.ndata['h'] = h - self.g.update_all(fn.copy_src(src='h', out='m'), - fn.sum(msg='m', out='h')) - h = self.g.ndata.pop('h') - # normalization by square root of dst degree - h = h * self.g.ndata['norm'] - # store precomputed result into a cached variable - self.cached_h = h - - def forward(self, mask): - h = self.weight(self.cached_h[mask]) - return h - -def evaluate(model, features, labels, mask): + +def evaluate(model, g, features, labels, mask): model.eval() with torch.no_grad(): - logits = model(mask) # only compute the evaluation set + logits = model(features, g)[mask] # only compute the evaluation set labels = labels[mask] _, indices = torch.max(logits, dim=1) correct = torch.sum(indices == labels) @@ -90,21 +63,13 @@ def main(args): n_edges = g.number_of_edges() # add self loop g.add_edges(g.nodes(), g.nodes()) - # normalization - degs = g.in_degrees().float() - norm = torch.pow(degs, -0.5) - norm[torch.isinf(norm)] = 0 - if cuda: - norm = norm.cuda() - g.ndata['norm'] = norm.unsqueeze(1) # create SGC model - model = SGCLayer(g, - features, - in_feats, - n_classes, - args.bias, - K=2) + model = SGConv(in_feats, + n_classes, + k=2, + cached=True, + bias=args.bias) if cuda: model.cuda() loss_fcn = torch.nn.CrossEntropyLoss() @@ -121,8 +86,8 @@ def main(args): if epoch >= 3: t0 = time.time() # forward - logits = model(train_mask) # only compute the train set - loss = loss_fcn(logits, labels[train_mask]) + logits = model(features, g) # only compute the train set + loss = loss_fcn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() @@ -131,13 +96,13 @@ def main(args): if epoch >= 3: dur.append(time.time() - t0) - acc = evaluate(model, features, labels, val_mask) + acc = evaluate(model, g, features, labels, val_mask) print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " "ETputs(KTEPS) {:.2f}". 
format(epoch, np.mean(dur), loss.item(), acc, n_edges / np.mean(dur) / 1000)) print() - acc = evaluate(model, features, labels, test_mask) + acc = evaluate(model, g, features, labels, test_mask) print("Test Accuracy {:.4f}".format(acc)) diff --git a/examples/pytorch/sgc/sgc_reddit.py b/examples/pytorch/sgc/sgc_reddit.py index eac67e6045b7..2f458e6a9b03 100644 --- a/examples/pytorch/sgc/sgc_reddit.py +++ b/examples/pytorch/sgc/sgc_reddit.py @@ -13,38 +13,15 @@ import dgl.function as fn from dgl import DGLGraph from dgl.data import register_data_args, load_data +from dgl.nn.pytorch.conv import SGConv -class SGCLayer(nn.Module): - def __init__(self,g,h,in_feats,out_feats,K=2): - super(SGCLayer, self).__init__() - self.g = g - self.weight = nn.Linear(in_feats, out_feats, bias=True) - self.K = K - # precomputing message passing - start = time.perf_counter() - for _ in range(self.K): - # normalization by square root of src degree - h = h * self.g.ndata['norm'] - self.g.ndata['h'] = h - self.g.update_all(fn.copy_src(src='h', out='m'), - fn.sum(msg='m', out='h')) - h = self.g.ndata.pop('h') - # normalization by square root of dst degree - h = h * self.g.ndata['norm'] - h = (h-h.mean(0))/h.std(0) - precompute_elapse = time.perf_counter()-start - print("Precompute Time(s): {:.4f}".format(precompute_elapse)) - # store precomputed result into a cached variable - self.cached_h = h +def normalize(h): + return (h-h.mean(0))/h.std(0) - def forward(self, mask): - h = self.weight(self.cached_h[mask]) - return h - -def evaluate(model, features, labels, mask): +def evaluate(model, features, graph, labels, mask): model.eval() with torch.no_grad(): - logits = model(mask) # only compute the evaluation set + logits = model(features, graph)[mask] # only compute the evaluation set labels = labels[mask] _, indices = torch.max(logits, dim=1) correct = torch.sum(indices == labels) @@ -85,7 +62,6 @@ def main(args): test_mask = test_mask.cuda() # graph preprocess and calculate normalization factor - start = time.perf_counter() g = DGLGraph(data.graph) n_edges = g.number_of_edges() # normalization @@ -94,14 +70,11 @@ def main(args): norm[torch.isinf(norm)] = 0 if cuda: norm = norm.cuda() g.ndata['norm'] = norm.unsqueeze(1) - preprocess_elapse = time.perf_counter()-start - print("Preprocessing Time: {:.4f}".format(preprocess_elapse)) # create SGC model - model = SGCLayer(g,features,in_feats,n_classes,K=2) - - if cuda: model.cuda() - loss_fcn = torch.nn.CrossEntropyLoss() + model = SGConv(in_feats, n_classes, k=2, cached=True, bias=True, norm=normalize) + if args.gpu >= 0: + model = model.cuda() # use optimizer optimizer = torch.optim.LBFGS(model.parameters()) @@ -109,22 +82,17 @@ def main(args): # define loss closure def closure(): optimizer.zero_grad() - output = model(train_mask) + output = model(features, g)[train_mask] loss_train = F.cross_entropy(output, labels[train_mask]) loss_train.backward() return loss_train # initialize graph - dur = [] - start = time.perf_counter() for epoch in range(args.n_epochs): model.train() - logits = model(train_mask) # only compute the train set - loss = optimizer.step(closure) + optimizer.step(closure) - train_elapse = time.perf_counter()-start - print("Train epoch {} | Train Time(s) {:.4f}".format(epoch, train_elapse)) - acc = evaluate(model, features, labels, test_mask) + acc = evaluate(model, features, g, labels, test_mask) print("Test Accuracy {:.4f}".format(acc)) diff --git a/examples/pytorch/tagcn/tagcn.py b/examples/pytorch/tagcn/tagcn.py index 22904b541055..2bfcff77a251 100644 
--- a/examples/pytorch/tagcn/tagcn.py +++ b/examples/pytorch/tagcn/tagcn.py @@ -7,7 +7,7 @@ """ import torch import torch.nn as nn -from dgl.nn.pytorch.conv import TGConv +from dgl.nn.pytorch.conv import TAGConv class TAGCN(nn.Module): def __init__(self, @@ -22,12 +22,12 @@ def __init__(self, self.g = g self.layers = nn.ModuleList() # input layer - self.layers.append(TGConv(in_feats, n_hidden, activation=activation)) + self.layers.append(TAGConv(in_feats, n_hidden, activation=activation)) # hidden layers for i in range(n_layers - 1): - self.layers.append(TGConv(n_hidden, n_hidden, activation=activation)) + self.layers.append(TAGConv(n_hidden, n_hidden, activation=activation)) # output layer - self.layers.append(TGConv(n_hidden, n_classes)) #activation=None + self.layers.append(TAGConv(n_hidden, n_classes)) #activation=None self.dropout = nn.Dropout(p=dropout) def forward(self, features): diff --git a/python/dgl/graph_index.py b/python/dgl/graph_index.py index bc5c57d54fc1..ad4f74c48afd 100644 --- a/python/dgl/graph_index.py +++ b/python/dgl/graph_index.py @@ -479,7 +479,7 @@ def in_degrees(self, v): Returns ------- - int + tensor The in degree array. """ v_array = v.todgltensor() @@ -510,7 +510,7 @@ def out_degrees(self, v): Returns ------- - int + tensor The out degree array. """ v_array = v.todgltensor() diff --git a/python/dgl/nn/mxnet/conv.py b/python/dgl/nn/mxnet/conv.py index 641322676a68..9acca0ecc309 100644 --- a/python/dgl/nn/mxnet/conv.py +++ b/python/dgl/nn/mxnet/conv.py @@ -1,5 +1,5 @@ """MXNet modules for graph convolutions.""" -# pylint: disable= no-member, arguments-differ +# pylint: disable= no-member, arguments-differ, invalid-name import math import mxnet as mx from mxnet import gluon, nd diff --git a/python/dgl/nn/mxnet/glob.py b/python/dgl/nn/mxnet/glob.py index 62eef9f40678..f3b89e775e03 100644 --- a/python/dgl/nn/mxnet/glob.py +++ b/python/dgl/nn/mxnet/glob.py @@ -1,5 +1,5 @@ """MXNet modules for graph global pooling.""" -# pylint: disable= no-member, arguments-differ, C0103, W0235 +# pylint: disable= no-member, arguments-differ, invalid-name, W0235 from mxnet import gluon, nd from mxnet.gluon import nn diff --git a/python/dgl/nn/pytorch/conv.py b/python/dgl/nn/pytorch/conv.py index 5acc901fd7dc..3ff769b59f9a 100644 --- a/python/dgl/nn/pytorch/conv.py +++ b/python/dgl/nn/pytorch/conv.py @@ -1,14 +1,36 @@ """Torch modules for graph convolutions.""" -# pylint: disable= no-member, arguments-differ +# pylint: disable= no-member, arguments-differ, invalid-name import torch as th from torch import nn from torch.nn import init +import torch.nn.functional as F from . import utils from ... import function as fn +from ...batched_graph import broadcast_nodes +from ...transform import laplacian_lambda_max +from .softmax import edge_softmax + +__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv', + 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv', + 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv', + 'DenseChebConv'] + +# pylint: disable=W0235 +class Identity(nn.Module): + """A placeholder identity operator that is argument-insensitive. + (Identity has already been supported by PyTorch 1.2, we will directly + import torch.nn.Identity in the future) + """ + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + """Return input""" + return x -__all__ = ['GraphConv', 'TGConv', 'RelGraphConv'] +# pylint: enable=W0235 class GraphConv(nn.Module): r"""Apply graph convolution over an input signal. 
@@ -41,9 +63,9 @@ class GraphConv(nn.Module): Parameters ---------- in_feats : int - Number of input features. + Input feature size. out_feats : int - Number of output features. + Output feature size. norm : bool, optional If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``. bias : bool, optional @@ -90,10 +112,10 @@ def forward(self, feat, graph): Notes ----- - * Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional - dimensions, :math:`N` is the number of nodes. - * Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are - the same shape as the input. + * Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional + dimensions, :math:`N` is the number of nodes. + * Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are + the same shape as the input. Parameters ---------- @@ -109,7 +131,7 @@ def forward(self, feat, graph): """ graph = graph.local_var() if self._norm: - norm = th.pow(graph.in_degrees().float(), -0.5) + norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5) shp = norm.shape + (1,) * (feat.dim() - 1) norm = th.reshape(norm, shp).to(feat.device) feat = feat * norm @@ -150,8 +172,125 @@ def extra_repr(self): summary += ', activation={_activation}' return summary.format(**self.__dict__) -class TGConv(nn.Module): - r"""Apply Topology Adaptive Graph Convolutional Network + +class GATConv(nn.Module): + r"""Apply `Graph Attention Network `__ + over an input signal. + + .. math:: + h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)} + + where :math:`\alpha_{ij}` is the attention score bewteen node :math:`i` and + node :math:`j`: + + .. math:: + \alpha_{ij}^{l} & = \mathrm{softmax_i} (e_{ij}^{l}) + + e_{ij}^{l} & = \mathrm{LeakyReLU}\left(\vec{a}^T [W h^{I} \| W h^{j}]\right) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + num_heads : int + Number of heads in Multi-Head Attention. + feat_drop : float, optional + Dropout rate on feature, defaults: ``0``. + attn_drop : float, optional + Dropout rate on attention weight, defaults: ``0``. + negative_slope : float, optional + LeakyReLU angle of negative slope. + residual : bool, optional + If True, use residual connection. + activation : callable activation function/layer or None, optional. + If not None, applies an activation function to the updated node features. + Default: ``None``. 
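+
+    A minimal usage sketch (the toy graph and feature sizes below are arbitrary)::
+
+        >>> import dgl
+        >>> import torch as th
+        >>> g = dgl.DGLGraph()
+        >>> g.add_nodes(3)
+        >>> g.add_edges([0, 1, 2], [1, 2, 0])
+        >>> feat = th.ones(3, 5)
+        >>> conv = GATConv(5, 2, num_heads=4)
+        >>> res = conv(feat, g)  # shape (3, 4, 2): nodes x heads x out_feats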
+ """ + def __init__(self, + in_feats, + out_feats, + num_heads, + feat_drop=0., + attn_drop=0., + negative_slope=0.2, + residual=False, + activation=None): + super(GATConv, self).__init__() + self._num_heads = num_heads + self._in_feats = in_feats + self._out_feats = out_feats + self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False) + self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) + self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) + self.feat_drop = nn.Dropout(feat_drop) + self.attn_drop = nn.Dropout(attn_drop) + self.leaky_relu = nn.LeakyReLU(negative_slope) + if residual: + if in_feats != out_feats: + self.res_fc = nn.Linear(in_feats, num_heads * out_feats, bias=False) + else: + self.res_fc = Identity() + else: + self.register_buffer('res_fc', None) + self.reset_parameters() + self.activation = activation + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + gain = nn.init.calculate_gain('relu') + nn.init.xavier_normal_(self.fc.weight, gain=gain) + nn.init.xavier_normal_(self.attn_l, gain=gain) + nn.init.xavier_normal_(self.attn_r, gain=gain) + if isinstance(self.res_fc, nn.Linear): + nn.init.xavier_normal_(self.res_fc.weight, gain=gain) + + def forward(self, feat, graph): + r"""Compute graph attention network layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, H, D_{out})` where :math:`H` + is the number of heads, and :math:`D_{out}` is size of output feature. + """ + graph = graph.local_var() + h = self.feat_drop(feat) + feat = self.fc(h).view(-1, self._num_heads, self._out_feats) + el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1) + er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1) + graph.ndata.update({'ft': feat, 'el': el, 'er': er}) + # compute edge attention + graph.apply_edges(fn.u_add_v('el', 'er', 'e')) + e = self.leaky_relu(graph.edata.pop('e')) + # compute softmax + graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) + # message passing + graph.update_all(fn.u_mul_e('ft', 'a', 'm'), + fn.sum('m', 'ft')) + rst = graph.ndata['ft'] + # residual + if self.res_fc is not None: + resval = self.res_fc(h).view(h.shape[0], -1, self._out_feats) + rst = rst + resval + # activation + if self.activation: + rst = self.activation(rst) + return rst + + +class TAGConv(nn.Module): + r"""Topology Adaptive Graph Convolutional layer from paper `Topology + Adaptive Graph Convolutional Networks `__. .. math:: \mathbf{X}^{\prime} = \sum_{k=0}^K \mathbf{D}^{-1/2} \mathbf{A} @@ -163,9 +302,9 @@ class TGConv(nn.Module): Parameters ---------- in_feats : int - Number of input features. + Input feature size. out_feats : int - Number of output features. + Output feature size. k: int, optional Number of hops :math: `k`. 
(default: 3) bias: bool, optional @@ -185,7 +324,7 @@ def __init__(self, k=2, bias=True, activation=None): - super(TGConv, self).__init__() + super(TAGConv, self).__init__() self._in_feats = in_feats self._out_feats = out_feats self._k = k @@ -196,26 +335,29 @@ def __init__(self, def reset_parameters(self): """Reinitialize learnable parameters.""" - self.lin.reset_parameters() + gain = nn.init.calculate_gain('relu') + nn.init.xavier_normal_(self.lin.weight, gain=gain) def forward(self, feat, graph): - r"""Compute graph convolution + r"""Compute topology adaptive graph convolution. Parameters ---------- feat : torch.Tensor - The input feature + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. graph : DGLGraph The graph. Returns ------- torch.Tensor - The output feature + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. """ graph = graph.local_var() - norm = th.pow(graph.in_degrees().float(), -0.5) + norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5) shp = norm.shape + (1,) * (feat.dim() - 1) norm = th.reshape(norm, shp).to(feat.device) @@ -380,7 +522,7 @@ def bdd_message_func(self, edges): return {'msg': msg} def forward(self, g, x, etypes, norm=None): - """Forward computation + """ Forward computation Parameters ---------- @@ -388,13 +530,13 @@ def forward(self, g, x, etypes, norm=None): The graph. x : torch.Tensor Input node features. Could be either - - (|V|, D) dense tensor - - (|V|,) int64 vector, representing the categorical values of each - node. We then treat the input feature as an one-hot encoding feature. + * :math:`(|V|, D)` dense tensor + * :math:`(|V|,)` int64 vector, representing the categorical values of each + node. We then treat the input feature as an one-hot encoding feature. etypes : torch.Tensor - Edge type tensor. Shape: (|E|,) + Edge type tensor. Shape: :math:`(|E|,)` norm : torch.Tensor - Optional edge normalizer tensor. Shape: (|E|, 1) + Optional edge normalizer tensor. Shape: :math:`(|E|, 1)` Returns ------- @@ -408,10 +550,8 @@ def forward(self, g, x, etypes, norm=None): g.edata['norm'] = norm if self.self_loop: loop_message = utils.matmul_maybe_select(x, self.loop_weight) - # message passing g.update_all(self.message_func, fn.sum(msg='msg', out='h')) - # apply bias and activation node_repr = g.ndata['h'] if self.bias: @@ -421,5 +561,1117 @@ def forward(self, g, x, etypes, norm=None): if self.activation: node_repr = self.activation(node_repr) node_repr = self.dropout(node_repr) - return node_repr + + +class SAGEConv(nn.Module): + r"""GraphSAGE layer from paper `Inductive Representation Learning on + Large Graphs `__. + + .. math:: + h_{\mathcal{N}(i)}^{(l+1)} & = \mathrm{aggregate} + \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right) + + h_{i}^{(l+1)} & = \sigma \left(W \cdot \mathrm{concat} + (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1} + b) \right) + + h_{i}^{(l+1)} & = \mathrm{norm}(h_{i}^{l}) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + feat_drop : float + Dropout rate on features, default: ``0``. + aggregator_type : str + Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``). + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + norm : callable activation function/layer or None, optional + If not None, applies normalization oto the updated node features. 
+ activation : callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + """ + def __init__(self, + in_feats, + out_feats, + aggregator_type, + feat_drop=0., + bias=True, + norm=None, + activation=None): + super(SAGEConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._aggre_type = aggregator_type + self.norm = norm + self.feat_drop = nn.Dropout(feat_drop) + self.activation = activation + # aggregator type: mean/pool/lstm/gcn + if aggregator_type == 'pool': + self.fc_pool = nn.Linear(in_feats, in_feats) + if aggregator_type == 'lstm': + self.lstm = nn.LSTM(in_feats, in_feats, batch_first=True) + if aggregator_type != 'gcn': + self.fc_self = nn.Linear(in_feats, out_feats, bias=bias) + self.fc_neigh = nn.Linear(in_feats, out_feats, bias=bias) + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + gain = nn.init.calculate_gain('relu') + if self._aggre_type == 'pool': + nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain) + if self._aggre_type == 'lstm': + self.lstm.reset_parameters() + if self._aggre_type != 'gcn': + nn.init.xavier_uniform_(self.fc_self.weight, gain=gain) + nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain) + + def _lstm_reducer(self, nodes): + """LSTM reducer + NOTE(zihao): lstm reducer with default schedule (degree bucketing) + is slow, we could accelerate this with degree padding in the future. + """ + m = nodes.mailbox['m'] # (B, L, D) + batch_size = m.shape[0] + h = (m.new_zeros((1, batch_size, self._in_feats)), + m.new_zeros((1, batch_size, self._in_feats))) + _, (rst, _) = self.lstm(m, h) + return {'neigh': rst.squeeze(0)} + + def forward(self, feat, graph): + r"""Compute GraphSAGE layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + graph = graph.local_var() + feat = self.feat_drop(feat) + h_self = feat + if self._aggre_type == 'mean': + graph.ndata['h'] = feat + graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) + h_neigh = graph.ndata['neigh'] + elif self._aggre_type == 'gcn': + graph.ndata['h'] = feat + graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh')) + # divide in_degrees + degs = graph.in_degrees().float() + degs = degs.to(feat.device) + h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1) + elif self._aggre_type == 'pool': + graph.ndata['h'] = F.relu(self.fc_pool(feat)) + graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh')) + h_neigh = graph.ndata['neigh'] + elif self._aggre_type == 'lstm': + graph.ndata['h'] = feat + graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer) + h_neigh = graph.ndata['neigh'] + else: + raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) + # GraphSAGE GCN does not require fc_self. 
+ if self._aggre_type == 'gcn': + rst = self.fc_neigh(h_neigh) + else: + rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) + # activation + if self.activation is not None: + rst = self.activation(rst) + # normalization + if self.norm is not None: + rst = self.norm(rst) + return rst + + +class GatedGraphConv(nn.Module): + r"""Gated Graph Convolution layer from paper `Gated Graph Sequence + Neural Networks `__. + + .. math:: + h_{i}^{0} & = [ x_i \| \mathbf{0} ] + + a_{i}^{t} & = \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t} + + h_{i}^{t+1} & = \mathrm{GRU}(a_{i}^{t}, h_{i}^{t}) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + n_steps : int + Number of recurrent steps. + n_etypes : int + Number of edge types. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + """ + def __init__(self, + in_feats, + out_feats, + n_steps, + n_etypes, + bias=True): + super(GatedGraphConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._n_steps = n_steps + self.edge_embed = nn.Embedding(n_etypes, out_feats * out_feats) + self.gru = nn.GRUCell(out_feats, out_feats, bias=bias) + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + gain = init.calculate_gain('relu') + self.gru.reset_parameters() + init.xavier_normal_(self.edge_embed.weight, gain=gain) + + def forward(self, feat, etypes, graph): + """Compute Gated Graph Convolution layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`N` + is the number of nodes of the graph and :math:`D_{in}` is the + input feature size. + etypes : torch.LongTensor + The edge type tensor of shape :math:`(E,)` where :math:`E` is + the number of edges of the graph. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is the output feature size. + """ + graph = graph.local_var() + zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1])) + feat = th.cat([feat, zero_pad], -1) + # NOTE(zihao): there is still room to optimize, we may do kernel fusion + # for such operations in the future. + graph.edata['w'] = self.edge_embed(etypes).view(-1, self._out_feats, self._out_feats) + for _ in range(self._n_steps): + graph.ndata['h'] = feat.unsqueeze(-1) # (N, D, 1) + graph.update_all(fn.u_mul_e('h', 'w', 'm'), + fn.sum('m', 'a')) + a = graph.ndata.pop('a').sum(dim=1) # (N, D) + feat = self.gru(a, feat) + return feat + + +class GMMConv(nn.Module): + r"""The Gaussian Mixture Model Convolution layer from `Geometric Deep + Learning on Graphs and Manifolds using Mixture Model CNNs + `__. + + .. math:: + h_i^{l+1} & = \mathrm{aggregate}\left(\left\{\frac{1}{K} + \sum_{k}^{K} w_k(u_{ij}), \forall j\in \mathcal{N}(i)\right\}\right) + + w_k(u) & = \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right) + + Parameters + ---------- + in_feats : int + Number of input features. + out_feats : int + Number of output features. + dim : int + Dimensionality of pseudo-coordinte. + n_kernels : int + Number of kernels :math:`K`. + aggregator_type : str + Aggregator type (``sum``, ``mean``, ``max``). + residual : bool + If True, use residual connection inside this layer. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. 
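+
+    A minimal usage sketch (the toy graph, feature sizes and pseudo-coordinates
+    below are arbitrary; ``pseudo`` needs one ``dim``-sized row per edge)::
+
+        >>> import dgl
+        >>> import torch as th
+        >>> g = dgl.DGLGraph()
+        >>> g.add_nodes(3)
+        >>> g.add_edges([0, 1, 2], [1, 2, 0])
+        >>> feat = th.ones(3, 5)
+        >>> pseudo = th.rand(g.number_of_edges(), 2)
+        >>> conv = GMMConv(5, 4, dim=2, n_kernels=3, aggregator_type='mean')
+        >>> res = conv(feat, pseudo, g)  # shape (3, 4)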
+ """ + def __init__(self, + in_feats, + out_feats, + dim, + n_kernels, + aggregator_type, + residual=True, + bias=True): + super(GMMConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._dim = dim + self._n_kernels = n_kernels + if aggregator_type == 'sum': + self._reducer = fn.sum + elif aggregator_type == 'mean': + self._reducer = fn.mean + elif aggregator_type == 'max': + self._reducer = fn.max + else: + raise KeyError("Aggregator type {} not recognized.".format(aggregator_type)) + + self.mu = nn.Parameter(th.Tensor(n_kernels, dim)) + self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim)) + self.fc = nn.Linear(in_feats, n_kernels * out_feats, bias=False) + if residual: + if in_feats != out_feats: + self.res_fc = nn.Linear(in_feats, out_feats, bias=False) + else: + self.res_fc = Identity() + else: + self.register_buffer('res_fc', None) + + if bias: + self.bias = nn.Parameter(th.Tensor(out_feats)) + else: + self.register_buffer('bias', None) + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + gain = init.calculate_gain('relu') + init.xavier_normal_(self.fc.weight, gain=gain) + if isinstance(self.res_fc, nn.Linear): + init.xavier_normal_(self.res_fc.weight, gain=gain) + init.normal_(self.mu.data, 0, 0.1) + init.normal_(self.inv_sigma.data, 1, 0.1) + if self.bias is not None: + init.zeros_(self.bias.data) + + def forward(self, feat, pseudo, graph): + """Compute Gaussian Mixture Model Convolution layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`N` + is the number of nodes of the graph and :math:`D_{in}` is the + input feature size. + pseudo : torch.Tensor + The pseudo coordinate tensor of shape :math:`(E, D_{u})` where + :math:`E` is the number of edges of the graph and :math:`D_{u}` + is the dimensionality of pseudo coordinate. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is the output feature size. + """ + graph = graph.local_var() + graph.ndata['h'] = self.fc(feat).view(-1, self._n_kernels, self._out_feats) + E = graph.number_of_edges() + # compute gaussian weight + gaussian = -0.5 * ((pseudo.view(E, 1, self._dim) - + self.mu.view(1, self._n_kernels, self._dim)) ** 2) + gaussian = gaussian * (self.inv_sigma.view(1, self._n_kernels, self._dim) ** 2) + gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1) + graph.edata['w'] = gaussian + graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h')) + rst = graph.ndata['h'].sum(1) + # residual connection + if self.res_fc is not None: + rst = rst + self.res_fc(feat) + # bias + if self.bias is not None: + rst = rst + self.bias + return rst + + +class GINConv(nn.Module): + r"""Graph Isomorphism Network layer from paper `How Powerful are Graph + Neural Networks? `__. + + .. math:: + h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} + + \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i) + \right\}\right)\right) + + Parameters + ---------- + apply_func : callable activation function/layer or None + If not None, apply this function to the updated node feature, + the :math:`f_\Theta` in the formula. + aggregator_type : str + Aggregator type to use (``sum``, ``max`` or ``mean``). + init_eps : float, optional + Initial :math:`\epsilon` value, default: ``0``. + learn_eps : bool, optional + If True, :math:`\epsilon` will be a learnable parameter. 
+ """ + def __init__(self, + apply_func, + aggregator_type, + init_eps=0, + learn_eps=False): + super(GINConv, self).__init__() + self.apply_func = apply_func + if aggregator_type == 'sum': + self._reducer = fn.sum + elif aggregator_type == 'max': + self._reducer = fn.max + elif aggregator_type == 'mean': + self._reducer = fn.mean + else: + raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type)) + # to specify whether eps is trainable or not. + if learn_eps: + self.eps = th.nn.Parameter(th.FloatTensor([init_eps])) + else: + self.register_buffer('eps', th.FloatTensor([init_eps])) + + def forward(self, feat, graph): + r"""Compute Graph Isomorphism Network layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D)` where :math:`D` + could be any positive integer, :math:`N` is the number + of nodes. If ``apply_func`` is not None, :math:`D` should + fit the input dimensionality requirement of ``apply_func``. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where + :math:`D_{out}` is the output dimensionality of ``apply_func``. + If ``apply_func`` is None, :math:`D_{out}` should be the same + as input dimensionality. + """ + graph = graph.local_var() + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) + rst = (1 + self.eps) * feat + graph.ndata['neigh'] + if self.apply_func is not None: + rst = self.apply_func(rst) + return rst + + +class ChebConv(nn.Module): + r"""Chebyshev Spectral Graph Convolution layer from paper `Convolutional + Neural Networks on Graphs with Fast Localized Spectral Filtering + `__. + + .. math:: + h_i^{l+1} &= \sum_{k=0}^{K-1} W^{k, l}z_i^{k, l} + + Z^{0, l} &= H^{l} + + Z^{1, l} &= \hat{L} \cdot H^{l} + + Z^{k, l} &= 2 \cdot \hat{L} \cdot Z^{k-1, l} - Z^{k-2, l} + + \hat{L} &= 2\left(I - \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}\right)/\lambda_{max} - I + + Parameters + ---------- + in_feats: int + Number of input features. + out_feats: int + Number of output features. + k : int + Chebyshev filter size. + bias : bool, optional + If True, adds a learnable bias to the output. Default: ``True``. + """ + def __init__(self, + in_feats, + out_feats, + k, + bias=True): + super(ChebConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self.fc = nn.ModuleList([ + nn.Linear(in_feats, out_feats, bias=False) for _ in range(k) + ]) + self._k = k + if bias: + self.bias = nn.Parameter(th.Tensor(out_feats)) + else: + self.register_buffer('bias', None) + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + if self.bias is not None: + init.zeros_(self.bias) + for module in self.fc.modules(): + if isinstance(module, nn.Linear): + init.xavier_normal_(module.weight, init.calculate_gain('relu')) + if module.bias is not None: + init.zeros_(module.bias) + + def forward(self, feat, graph, lambda_max=None): + r"""Compute ChebNet layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. 
+ """ + with graph.local_scope(): + norm = th.pow( + graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device) + if lambda_max is None: + lambda_max = laplacian_lambda_max(graph) + lambda_max = th.Tensor(lambda_max).to(feat.device) + if lambda_max.dim() < 1: + lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1) + # broadcast from (B, 1) to (N, 1) + lambda_max = broadcast_nodes(graph, lambda_max) + # T0(X) + + Tx_0 = feat + rst = self.fc[0](Tx_0) + # T1(X) + if self._k > 1: + graph.ndata['h'] = Tx_0 * norm + graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) + h = graph.ndata.pop('h') * norm + # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I + # = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I + Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1) + rst = rst + self.fc[1](Tx_1) + # Ti(x), i = 2...k + for i in range(2, self._k): + graph.ndata['h'] = Tx_1 * norm + graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) + h = graph.ndata.pop('h') * norm + # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2) + # = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) + + # (4 / lambda_max - 2) Tx_(k-1) - + # Tx_(k-2) + Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max - 2) - Tx_0 + rst = rst + self.fc[i](Tx_2) + Tx_1, Tx_0 = Tx_2, Tx_1 + # add bias + if self.bias is not None: + rst = rst + self.bias + return rst + + +class SGConv(nn.Module): + r"""Simplifying Graph Convolution layer from paper `Simplifying Graph + Convolutional Networks `__. + + .. math:: + H^{l+1} = (\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2})^K H^{l} \Theta^{l} + + Parameters + ---------- + in_feats : int + Number of input features. + out_feats : int + Number of output features. + k : int + Number of hops :math:`K`. Defaults:``1``. + cached : bool + If True, the module would cache + + .. math:: + (\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}})^K X\Theta + + at the first forward call. This parameter should only be set to + ``True`` in Transductive Learning setting. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + norm : callable activation function/layer or None, optional + If not None, applies normalization oto the updated node features. + """ + def __init__(self, + in_feats, + out_feats, + k=1, + cached=False, + bias=True, + norm=None): + super(SGConv, self).__init__() + self.fc = nn.Linear(in_feats, out_feats, bias=bias) + self._cached = cached + self._cached_h = None + self._k = k + self.norm = norm + + def forward(self, feat, graph): + r"""Compute Simplifying Graph Convolution layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + + Notes + ----- + If ``cache`` is se to True, ``feat`` and ``graph`` should not change during + training, or you will get wrong results. 
+ """ + graph = graph.local_var() + if self._cached_h is not None: + feat = self._cached_h + else: + # compute normalization + degs = graph.in_degrees().float().clamp(min=1) + norm = th.pow(degs, -0.5) + norm[th.isinf(norm)] = 0 + norm = norm.to(feat.device).unsqueeze(1) + # compute (D^-1 A D) X + for _ in range(self._k): + feat = feat * norm + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), + fn.sum('m', 'h')) + feat = graph.ndata.pop('h') + feat = feat * norm + + if self.norm is not None: + feat = self.norm(feat) + + # cache feature + if self._cached: + self._cached_h = feat + return self.fc(feat) + + +class NNConv(nn.Module): + r"""Graph Convolution layer introduced in `Neural Message Passing + for Quantum Chemistry `__. + + .. math:: + h_{i}^{l+1} = h_{i}^{l} + \mathrm{aggregate}\left(\left\{ + f_\Theta (e_{ij}) \cdot h_j^{l}, j\in \mathcal{N}(i) \right\}\right) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + edge_func : callable activation function/layer + Maps each edge feature to a vector of shape + ``(in_feats * out_feats)`` as weight to compute + messages. + Also is the :math:`f_\Theta` in the formula. + aggregator_type : str + Aggregator type to use (``sum``, ``mean`` or ``max``). + residual : bool, optional + If True, use residual connection. Default: ``False``. + bias : bool, optional + If True, adds a learnable bias to the output. Default: ``True``. + """ + def __init__(self, + in_feats, + out_feats, + edge_func, + aggregator_type, + residual=False, + bias=True): + super(NNConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self.edge_nn = edge_func + if aggregator_type == 'sum': + self.reducer = fn.sum + elif aggregator_type == 'mean': + self.reducer = fn.mean + elif aggregator_type == 'max': + self.reducer = fn.max + else: + raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type)) + self._aggre_type = aggregator_type + if residual: + if in_feats != out_feats: + self.res_fc = nn.Linear(in_feats, out_feats, bias=False) + else: + self.res_fc = Identity() + else: + self.register_buffer('res_fc', None) + if bias: + self.bias = nn.Parameter(th.Tensor(out_feats)) + else: + self.register_buffer('bias', None) + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + gain = init.calculate_gain('relu') + if self.bias is not None: + nn.init.zeros_(self.bias) + if isinstance(self.res_fc, nn.Linear): + nn.init.xavier_normal_(self.res_fc.weight, gain=gain) + + def forward(self, feat, efeat, graph): + r"""Compute MPNN Graph Convolution layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`N` + is the number of nodes of the graph and :math:`D_{in}` is the + input feature size. + efeat : torch.Tensor + The edge feature of shape :math:`(N, *)`, should fit the input + shape requirement of ``edge_nn``. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is the output feature size. 
+ """ + graph = graph.local_var() + # (n, d_in, 1) + graph.ndata['h'] = feat.unsqueeze(-1) + # (n, d_in, d_out) + graph.edata['w'] = self.edge_nn(efeat).view(-1, self._in_feats, self._out_feats) + # (n, d_in, d_out) + graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh')) + rst = graph.ndata.pop('neigh').sum(dim=1) # (n, d_out) + # residual connection + if self.res_fc is not None: + rst = rst + self.res_fc(feat) + # bias + if self.bias is not None: + rst = rst + self.bias + return rst + + +class APPNPConv(nn.Module): + r"""Approximate Personalized Propagation of Neural Predictions + layer from paper `Predict then Propagate: Graph Neural Networks + meet Personalized PageRank `__. + + .. math:: + H^{0} & = X + + H^{t+1} & = (1-\alpha)\left(\hat{D}^{-1/2} + \hat{A} \hat{D}^{-1/2} H^{t} + \alpha H^{0}\right) + + Parameters + ---------- + k : int + Number of iterations :math:`K`. + alpha : float + The teleport probability :math:`\alpha`. + edge_drop : float, optional + Dropout rate on edges that controls the + messages received by each node. Default: ``0``. + """ + def __init__(self, + k, + alpha, + edge_drop=0.): + super(APPNPConv, self).__init__() + self._k = k + self._alpha = alpha + self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity() + + def forward(self, feat, graph): + r"""Compute APPNP layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, *)` :math:`N` is the + number of nodes, and :math:`*` could be of any shape. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, *)` where :math:`*` + should be the same as input shape. + """ + graph = graph.local_var() + norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5) + norm = norm.unsqueeze(-1).to(feat.device) + feat_0 = feat + for _ in range(self._k): + # normalization by src + feat = feat * norm + graph.ndata['h'] = feat + graph.edata['w'] = self.edge_drop( + th.ones(graph.number_of_edges(), 1).to(feat.device)) + graph.update_all(fn.u_mul_e('h', 'w', 'm'), + fn.sum('m', 'h')) + feat = graph.ndata.pop('h') + # normalization by dst + feat = feat * norm + feat = (1 - self._alpha) * feat + self._alpha * feat_0 + return feat + + +class AGNNConv(nn.Module): + r"""Attention-based Graph Neural Network layer from paper `Attention-based + Graph Neural Network for Semi-Supervised Learning + `__. + + .. math:: + H^{l+1} = P H^{l} + + where :math:`P` is computed as: + + .. math:: + P_{ij} = \mathrm{softmax}_i ( \beta \cdot \cos(h_i^l, h_j^l)) + + Parameters + ---------- + init_beta : float, optional + The :math:`\beta` in the formula. + learn_beta : bool, optional + If True, :math:`\beta` will be learnable parameter. + """ + def __init__(self, + init_beta=1., + learn_beta=True): + super(AGNNConv, self).__init__() + if learn_beta: + self.beta = nn.Parameter(th.Tensor([init_beta])) + else: + self.register_buffer('beta', th.Tensor([init_beta])) + + def forward(self, feat, graph): + r"""Compute AGNN layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, *)` :math:`N` is the + number of nodes, and :math:`*` could be of any shape. + graph : DGLGraph + The graph. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, *)` where :math:`*` + should be the same as input shape. 
+ """ + graph = graph.local_var() + graph.ndata['h'] = feat + graph.ndata['norm_h'] = F.normalize(feat, p=2, dim=-1) + # compute cosine distance + graph.apply_edges(fn.u_mul_v('norm_h', 'norm_h', 'cos')) + cos = graph.edata.pop('cos').sum(-1) + e = self.beta * cos + graph.edata['p'] = edge_softmax(graph, e) + graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h')) + return graph.ndata.pop('h') + + +class DenseGraphConv(nn.Module): + """Graph Convolutional Network layer where the graph structure + is given by an adjacency matrix. + We recommend user to use this module when inducing graph convolution + on dense graphs / k-hop graphs. + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + norm : bool + If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + activation : callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + + See also + -------- + GraphConv + """ + def __init__(self, + in_feats, + out_feats, + norm=True, + bias=True, + activation=None): + super(DenseGraphConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._norm = norm + self.weight = nn.Parameter(th.Tensor(in_feats, out_feats)) + if bias: + self.bias = nn.Parameter(th.Tensor(out_feats)) + else: + self.register_buffer('bias', None) + + self.reset_parameters() + self._activation = activation + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + init.xavier_uniform_(self.weight) + if self.bias is not None: + init.zeros_(self.bias) + + def forward(self, feat, adj): + r"""Compute (Dense) Graph Convolution layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + adj : torch.Tensor + The adjacency matrix of the graph to apply Graph Convolution on, + should be of shape :math:`(N, N)`, where a row represents the destination + and a column represents the source. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + adj = adj.float().to(feat.device) + if self._norm: + in_degrees = adj.sum(dim=1) + norm = th.pow(in_degrees, -0.5) + shp = norm.shape + (1,) * (feat.dim() - 1) + norm = th.reshape(norm, shp).to(feat.device) + feat = feat * norm + + if self._in_feats > self._out_feats: + # mult W first to reduce the feature size for aggregation. + feat = th.matmul(feat, self.weight) + rst = adj @ feat + else: + # aggregate first then mult W + rst = adj @ feat + rst = th.matmul(rst, self.weight) + + if self._norm: + rst = rst * norm + + if self.bias is not None: + rst = rst + self.bias + + if self._activation is not None: + rst = self._activation(rst) + + return rst + + +class DenseSAGEConv(nn.Module): + """GraphSAGE layer where the graph structure is given by an + adjacency matrix. + We recommend to use this module when inducing GraphSAGE operations + on dense graphs / k-hop graphs. + + Note that we only support gcn aggregator in DenseSAGEConv. + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + feat_drop : float, optional + Dropout rate on features. Default: 0. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. 
+ norm : callable activation function/layer or None, optional + If not None, applies normalization oto the updated node features. + activation : callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + + See also + -------- + SAGEConv + """ + def __init__(self, + in_feats, + out_feats, + feat_drop=0., + bias=True, + norm=None, + activation=None): + super(DenseSAGEConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._norm = norm + self.feat_drop = nn.Dropout(feat_drop) + self.activation = activation + self.fc = nn.Linear(in_feats, out_feats, bias=bias) + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + gain = nn.init.calculate_gain('relu') + nn.init.xavier_uniform_(self.fc.weight, gain=gain) + + def forward(self, feat, adj): + r"""Compute (Dense) Graph SAGE layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + adj : torch.Tensor + The adjacency matrix of the graph to apply Graph Convolution on, + should be of shape :math:`(N, N)`, where a row represents the destination + and a column represents the source. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + adj = adj.float().to(feat.device) + feat = self.feat_drop(feat) + in_degrees = adj.sum(dim=1).unsqueeze(-1) + h_neigh = (adj @ feat + feat) / (in_degrees + 1) + rst = self.fc(h_neigh) + # activation + if self.activation is not None: + rst = self.activation(rst) + # normalization + if self._norm is not None: + rst = self._norm(rst) + + return rst + + +class DenseChebConv(nn.Module): + r"""Chebyshev Spectral Graph Convolution layer from paper `Convolutional + Neural Networks on Graphs with Fast Localized Spectral Filtering + `__. + + We recommend to use this module when inducing ChebConv operations on dense + graphs / k-hop graphs. + + Parameters + ---------- + in_feats: int + Number of input features. + out_feats: int + Number of output features. + k : int + Chebyshev filter size. + bias : bool, optional + If True, adds a learnable bias to the output. Default: ``True``. + + See also + -------- + ChebConv + """ + def __init__(self, + in_feats, + out_feats, + k, + bias=True): + super(DenseChebConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._k = k + self.W = nn.Parameter(th.Tensor(k, in_feats, out_feats)) + if bias: + self.bias = nn.Parameter(th.Tensor(out_feats)) + else: + self.register_buffer('bias', None) + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + if self.bias is not None: + init.zeros_(self.bias) + for i in range(self._k): + init.xavier_normal_(self.W[i], init.calculate_gain('relu')) + + def forward(self, feat, adj): + r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer. + + Parameters + ---------- + feat : torch.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + adj : torch.Tensor + The adjacency matrix of the graph to apply Graph Convolution on, + should be of shape :math:`(N, N)`, where a row represents the destination + and a column represents the source. 
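A minimal sketch (not part of the diff) of the DenseSAGEConv layer above, which performs gcn-style aggregation directly over a dense adjacency matrix; the toy graph is illustrative only.

import torch as th
import dgl
from dgl.nn.pytorch.conv import DenseSAGEConv

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3])
adj = g.adjacency_matrix().to_dense()   # rows: destinations, columns: sources
feat = th.randn(5, 4)
conv = DenseSAGEConv(4, 6)
h = conv(feat, adj)                     # shape (5, 6)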
+ + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + A = adj.to(feat) + num_nodes = A.shape[0] + + in_degree = 1 / A.sum(dim=1).clamp(min=1).sqrt() + D_invsqrt = th.diag(in_degree) + I = th.eye(num_nodes).to(A) + L = I - D_invsqrt @ A @ D_invsqrt + + lambda_ = th.eig(L)[0][:, 0] + lambda_max = lambda_.max() + L_hat = 2 * L / lambda_max - I + + Z = [th.eye(num_nodes).to(A)] + for i in range(1, self._k): + if i == 1: + Z.append(L_hat) + else: + Z.append(2 * L_hat @ Z[-1] - Z[-2]) + + Zs = th.stack(Z, 0) # (k, n, n) + + Zh = (Zs @ feat.unsqueeze(0) @ self.W) + Zh = Zh.sum(0) + + if self.bias is not None: + Zh = Zh + self.bias + return Zh diff --git a/python/dgl/nn/pytorch/glob.py b/python/dgl/nn/pytorch/glob.py index 26fd52444515..b925faf8f839 100644 --- a/python/dgl/nn/pytorch/glob.py +++ b/python/dgl/nn/pytorch/glob.py @@ -1,5 +1,5 @@ """Torch modules for graph global pooling.""" -# pylint: disable= no-member, arguments-differ, C0103, W0235 +# pylint: disable= no-member, arguments-differ, invalid-name, W0235 import torch as th import torch.nn as nn import numpy as np @@ -178,17 +178,6 @@ def __init__(self, gate_nn, feat_nn=None): super(GlobalAttentionPooling, self).__init__() self.gate_nn = gate_nn self.feat_nn = feat_nn - self.reset_parameters() - - def reset_parameters(self): - """Reinitialize learnable parameters.""" - for p in self.gate_nn.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - if self.feat_nn: - for p in self.feat_nn.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) def forward(self, feat, graph): r"""Compute global attention pooling. diff --git a/python/dgl/transform.py b/python/dgl/transform.py index a29b53595a0d..c364b8d21e18 100644 --- a/python/dgl/transform.py +++ b/python/dgl/transform.py @@ -1,9 +1,16 @@ -"""Module for graph transformation methods.""" +"""Module for graph transformation utilities.""" + +import numpy as np +from scipy import sparse from ._ffi.function import _init_api from .graph import DGLGraph -from .batched_graph import BatchedDGLGraph +from .graph_index import from_coo +from .batched_graph import BatchedDGLGraph, unbatch +from .backend import asnumpy, tensor + -__all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected'] +__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected', + 'laplacian_lambda_max'] def line_graph(g, backtracking=True, shared=False): @@ -12,6 +19,7 @@ def line_graph(g, backtracking=True, shared=False): Parameters ---------- g : dgl.DGLGraph + The input graph. backtracking : bool, optional Whether the returned line graph is backtracking. shared : bool, optional @@ -26,6 +34,88 @@ def line_graph(g, backtracking=True, shared=False): node_frame = g._edge_frame if shared else None return DGLGraph(graph_data, node_frame) +def khop_adj(g, k): + """Return the matrix of :math:`A^k` where :math:`A` is the adjacency matrix of :math:`g`, + where a row represents the destination and a column represents the source. + + Parameters + ---------- + g : dgl.DGLGraph + The input graph. + k : int + The :math:`k` in :math:`A^k`. + + Returns + ------- + tensor + The returned tensor, dtype is ``np.float32``. 
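A minimal sketch (not part of the diff) of the DenseChebConv layer completed above. It assumes a PyTorch version where ``torch.eig`` is still available, since the layer calls it internally to compute lambda_max; the toy graph is illustrative only.

import torch as th
import dgl
from dgl.nn.pytorch.conv import DenseChebConv

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3])
adj = g.adjacency_matrix().to_dense()
feat = th.randn(5, 4)
conv = DenseChebConv(4, 2, k=3)
h = conv(feat, adj)    # shape (5, 2); lambda_max is computed from the dense Laplacian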
+ + Examples + -------- + + >>> import dgl + >>> g = dgl.DGLGraph() + >>> g.add_nodes(5) + >>> g.add_edges([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0]) + >>> dgl.khop_adj(g, 1) + tensor([[1., 0., 0., 0., 1.], + [1., 1., 0., 0., 0.], + [0., 1., 1., 0., 0.], + [0., 0., 1., 1., 0.], + [0., 0., 0., 1., 1.]]) + >>> dgl.khop_adj(g, 3) + tensor([[1., 0., 1., 3., 3.], + [3., 1., 0., 1., 3.], + [3., 3., 1., 0., 1.], + [1., 3., 3., 1., 0.], + [0., 1., 3., 3., 1.]]) + """ + adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k + return tensor(adj_k.todense().astype(np.float32)) + +def khop_graph(g, k): + """Return the graph that includes all :math:`k`-hop neighbors of the given graph as edges. + The adjacency matrix of the returned graph is :math:`A^k` + (where :math:`A` is the adjacency matrix of :math:`g`). + + Parameters + ---------- + g : dgl.DGLGraph + The input graph. + k : int + The :math:`k` in `k`-hop graph. + + Returns + ------- + dgl.DGLGraph + The returned ``DGLGraph``. + + Examples + -------- + + >>> import dgl + >>> g = dgl.DGLGraph() + >>> g.add_nodes(5) + >>> g.add_edges([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0]) + >>> dgl.khop_graph(g, 1) + DGLGraph(num_nodes=5, num_edges=10, + ndata_schemes={} + edata_schemes={}) + >>> dgl.khop_graph(g, 3) + DGLGraph(num_nodes=5, num_edges=40, + ndata_schemes={} + edata_schemes={}) + """ + n = g.number_of_nodes() + adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k + adj_k = adj_k.tocoo() + multiplicity = adj_k.data + row = np.repeat(adj_k.row, multiplicity) + col = np.repeat(adj_k.col, multiplicity) + # TODO(zihao): we should support creating multi-graph from scipy sparse matrix + # in the future. + return DGLGraph(from_coo(n, row, col, True, True)) + def reverse(g, share_ndata=False, share_edata=False): """Return the reverse of a graph @@ -46,6 +136,7 @@ def reverse(g, share_ndata=False, share_edata=False): Parameters ---------- g : dgl.DGLGraph + The input graph. share_ndata: bool, optional If True, the original graph and the reversed graph share memory for node attributes. Otherwise the reversed graph will not be initialized with node attributes. @@ -169,4 +260,49 @@ def to_bidirected(g, readonly=True): newgidx = _CAPI_DGLToBidirectedMutableGraph(g._graph) return DGLGraph(newgidx) +def laplacian_lambda_max(g): + """Return the largest eigenvalue of the normalized symmetric laplacian of g. + + The eigenvalue of the normalized symmetric of any graph is less than or equal to 2, + ref: https://en.wikipedia.org/wiki/Laplacian_matrix#Properties + + Parameters + ---------- + g : DGLGraph or BatchedDGLGraph + The input graph, it should be an undirected graph. + + Returns + ------- + list : + * If the input g is a DGLGraph, the returned value would be + a list with one element, indicating the largest eigenvalue of g. + * If the input g is a BatchedDGLGraph, the returned value would + be a list, where the i-th item indicates the largest eigenvalue + of i-th graph in g. 
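A small sanity check (not part of the diff) tying ``khop_adj`` and ``khop_graph`` together: the k-hop adjacency is literally the k-th matrix power of the adjacency, and the k-hop graph carries one edge per k-hop walk. The graph below is the same toy graph used in the docstring examples; a PyTorch backend is assumed so the returned adjacency is a torch tensor.

import torch as th
import dgl

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [0, 1, 2, 3, 4, 1, 2, 3, 4, 0])

A1 = dgl.khop_adj(g, 1)
A2 = dgl.khop_adj(g, 2)
assert th.allclose(A2, A1 @ A1)        # A^2 is the square of the adjacency

g2 = dgl.khop_graph(g, 2)              # one edge per 2-hop walk in g
print(g2.number_of_edges())            # equals the sum of the entries of A^2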
+ + Examples + -------- + + >>> import dgl + >>> g = dgl.DGLGraph() + >>> g.add_nodes(5) + >>> g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3]) + >>> dgl.laplacian_lambda_max(g) + [1.809016994374948] + """ + if isinstance(g, BatchedDGLGraph): + g_arr = unbatch(g) + else: + g_arr = [g] + + rst = [] + for g_i in g_arr: + n = g_i.number_of_nodes() + adj = g_i.adjacency_matrix_scipy(return_edge_ids=False).astype(float) + norm = sparse.diags(asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float) + laplacian = sparse.eye(n) - norm * adj * norm + rst.append(sparse.linalg.eigs(laplacian, 1, which='LM', + return_eigenvectors=False)[0].real) + return rst + _init_api("dgl.transform") diff --git a/tests/backend/backend_unittest.py b/tests/backend/backend_unittest.py index 3c7e90542421..d80854bdcdf2 100644 --- a/tests/backend/backend_unittest.py +++ b/tests/backend/backend_unittest.py @@ -110,6 +110,11 @@ def min(x, dim): def prod(x, dim): """Computes the prod of array elements over given axes""" pass + +def matmul(a, b): + """Compute Matrix Multiplication between a and b""" + pass + ############################################################################### # Tensor functions used *only* on index tensor # ---------------- diff --git a/tests/backend/mxnet/__init__.py b/tests/backend/mxnet/__init__.py index cd678699b8ca..4191c05c9f55 100644 --- a/tests/backend/mxnet/__init__.py +++ b/tests/backend/mxnet/__init__.py @@ -83,6 +83,9 @@ def min(x, dim): def prod(x, dim): return x.prod(dim) +def matmul(a, b): + return nd.dot(a, b) + record_grad = autograd.record diff --git a/tests/backend/pytorch/__init__.py b/tests/backend/pytorch/__init__.py index c7ffa72f655b..01cc2a2e3c3e 100644 --- a/tests/backend/pytorch/__init__.py +++ b/tests/backend/pytorch/__init__.py @@ -79,6 +79,9 @@ def min(x, dim): def prod(x, dim): return x.prod(dim) +def matmul(a, b): + return a @ b + class record_grad(object): def __init__(self): pass diff --git a/tests/compute/test_transform.py b/tests/compute/test_transform.py index f67d8d1ea660..2be70bce68d3 100644 --- a/tests/compute/test_transform.py +++ b/tests/compute/test_transform.py @@ -112,6 +112,56 @@ def _test(in_readonly, out_readonly): _test(False, True) _test(False, False) +def test_khop_graph(): + N = 20 + feat = F.randn((N, 5)) + g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3)) + for k in range(4): + g_k = dgl.khop_graph(g, k) + # use original graph to do message passing for k times. + g.ndata['h'] = feat + for _ in range(k): + g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) + h_0 = g.ndata.pop('h') + # use k-hop graph to do message passing for one time. + g_k.ndata['h'] = feat + g_k.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) + h_1 = g_k.ndata.pop('h') + assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3) + +def test_khop_adj(): + N = 20 + feat = F.randn((N, 5)) + g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3)) + for k in range(3): + adj = F.tensor(dgl.khop_adj(g, k)) + # use original graph to do message passing for k times. + g.ndata['h'] = feat + for _ in range(k): + g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) + h_0 = g.ndata.pop('h') + # use k-hop adj to do message passing for one time. 
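A quick numeric check (not part of the diff) for ``laplacian_lambda_max``: the normalized Laplacian of an undirected n-cycle has eigenvalues 1 - cos(2*pi*k/n), so for the bidirected 5-cycle in the docstring example the largest eigenvalue is 1 + cos(pi/5), roughly 1.809.

import math
import dgl

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3])
lmax = dgl.laplacian_lambda_max(g)[0]
assert abs(lmax - (1 + math.cos(math.pi / 5))) < 1e-6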
+ h_1 = F.matmul(adj, feat) + assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3) + +def test_laplacian_lambda_max(): + N = 20 + eps = 1e-6 + # test DGLGraph + g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3)) + l_max = dgl.laplacian_lambda_max(g) + assert (l_max[0] < 2 + eps) + # test BatchedDGLGraph + N_arr = [20, 30, 10, 12] + bg = dgl.batch([ + dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3)) + for N in N_arr + ]) + l_max_arr = dgl.laplacian_lambda_max(bg) + assert len(l_max_arr) == len(N_arr) + for l_max in l_max_arr: + assert l_max < 2 + eps + if __name__ == '__main__': test_line_graph() test_no_backtracking() @@ -119,3 +169,6 @@ def _test(in_readonly, out_readonly): test_reverse_shared_frames() test_simple_graph() test_bidirected_graph() + test_khop_adj() + test_khop_graph() + test_laplacian_lambda_max() diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py index 317befa657fe..58b5c8698c55 100644 --- a/tests/pytorch/test_nn.py +++ b/tests/pytorch/test_nn.py @@ -20,7 +20,7 @@ def test_graph_conv(): conv = nn.GraphConv(5, 2, norm=False, bias=True) if F.gpu_ctx(): - conv.cuda() + conv = conv.to(ctx) print(conv) # test#1: basic h0 = F.ones((3, 5)) @@ -37,7 +37,7 @@ def test_graph_conv(): conv = nn.GraphConv(5, 2) if F.gpu_ctx(): - conv.cuda() + conv = conv.to(ctx) # test#3: basic h0 = F.ones((3, 5)) h1 = conv(h0, g) @@ -51,7 +51,7 @@ def test_graph_conv(): conv = nn.GraphConv(5, 2) if F.gpu_ctx(): - conv.cuda() + conv = conv.to(ctx) # test#3: basic h0 = F.ones((3, 5)) h1 = conv(h0, g) @@ -81,15 +81,15 @@ def _S2AXWb(A, N, X, W, b): return Y + b -def test_tgconv(): +def test_tagconv(): g = dgl.DGLGraph(nx.path_graph(3)) ctx = F.ctx() adj = g.adjacency_matrix(ctx=ctx) norm = th.pow(g.in_degrees().float(), -0.5) - conv = nn.TGConv(5, 2, bias=True) + conv = nn.TAGConv(5, 2, bias=True) if F.gpu_ctx(): - conv.cuda() + conv = conv.to(ctx) print(conv) # test#1: basic @@ -102,27 +102,27 @@ def test_tgconv(): assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)) - conv = nn.TGConv(5, 2) + conv = nn.TAGConv(5, 2) if F.gpu_ctx(): - conv.cuda() + conv = conv.to(ctx) # test#2: basic h0 = F.ones((3, 5)) h1 = conv(h0, g) - assert len(g.ndata) == 0 - assert len(g.edata) == 0 + assert h1.shape[-1] == 2 - # test rest_parameters + # test reset_parameters old_weight = deepcopy(conv.lin.weight.data) conv.reset_parameters() new_weight = conv.lin.weight.data assert not F.allclose(old_weight, new_weight) def test_set2set(): + ctx = F.ctx() g = dgl.DGLGraph(nx.path_graph(10)) s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers if F.gpu_ctx(): - s2s.cuda() + s2s = s2s.to(ctx) print(s2s) # test#1: basic @@ -139,11 +139,12 @@ def test_set2set(): assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2 def test_glob_att_pool(): + ctx = F.ctx() g = dgl.DGLGraph(nx.path_graph(10)) gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10)) if F.gpu_ctx(): - gap.cuda() + gap = gap.to(ctx) print(gap) # test#1: basic @@ -158,6 +159,7 @@ def test_glob_att_pool(): assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2 def test_simple_pool(): + ctx = F.ctx() g = dgl.DGLGraph(nx.path_graph(15)) sum_pool = nn.SumPooling() @@ -168,6 +170,12 @@ def test_simple_pool(): # test#1: basic h0 = F.randn((g.number_of_nodes(), 5)) + if F.gpu_ctx(): + sum_pool = sum_pool.to(ctx) + avg_pool = avg_pool.to(ctx) + max_pool = max_pool.to(ctx) + sort_pool = sort_pool.to(ctx) + h0 = h0.to(ctx) h1 = sum_pool(h0, g) assert F.allclose(h1, F.sum(h0, 0)) h1 = avg_pool(h0, g) @@ -181,6 
+189,8 @@ def test_simple_pool(): g_ = dgl.DGLGraph(nx.path_graph(5)) bg = dgl.batch([g, g_, g, g_, g]) h0 = F.randn((bg.number_of_nodes(), 5)) + if F.gpu_ctx(): + h0 = h0.to(ctx) h1 = sum_pool(h0, bg) truth = th.stack([F.sum(h0[:15], 0), @@ -210,15 +220,16 @@ def test_simple_pool(): assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2 def test_set_trans(): + ctx = F.ctx() g = dgl.DGLGraph(nx.path_graph(15)) st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab') st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3) st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4) if F.gpu_ctx(): - st_enc_0.cuda() - st_enc_1.cuda() - st_dec.cuda() + st_enc_0 = st_enc_0.to(ctx) + st_enc_1 = st_enc_1.to(ctx) + st_dec = st_dec.to(ctx) print(st_enc_0, st_enc_1, st_dec) # test#1: basic @@ -354,6 +365,207 @@ def test_rgcn(): h_new = rgc_basis(g, h, r) assert list(h_new.shape) == [100, O] +def test_gat_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + gat = nn.GATConv(5, 2, 4) + feat = F.randn((100, 5)) + + if F.gpu_ctx(): + gat = gat.to(ctx) + feat = feat.to(ctx) + + h = gat(feat, g) + assert h.shape[-1] == 2 and h.shape[-2] == 4 + +def test_sage_conv(): + for aggre_type in ['mean', 'pool', 'gcn', 'lstm']: + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + sage = nn.SAGEConv(5, 10, aggre_type) + feat = F.randn((100, 5)) + + if F.gpu_ctx(): + sage = sage.to(ctx) + feat = feat.to(ctx) + + h = sage(feat, g) + assert h.shape[-1] == 10 + +def test_sgc_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + # not cached + sgc = nn.SGConv(5, 10, 3) + feat = F.randn((100, 5)) + + if F.gpu_ctx(): + sgc = sgc.to(ctx) + feat = feat.to(ctx) + + h = sgc(feat, g) + assert h.shape[-1] == 10 + + # cached + sgc = nn.SGConv(5, 10, 3, True) + + if F.gpu_ctx(): + sgc = sgc.to(ctx) + + h_0 = sgc(feat, g) + h_1 = sgc(feat + 1, g) + assert F.allclose(h_0, h_1) + assert h_0.shape[-1] == 10 + +def test_appnp_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + appnp = nn.APPNPConv(10, 0.1) + feat = F.randn((100, 5)) + + if F.gpu_ctx(): + appnp = appnp.to(ctx) + feat = feat.to(ctx) + + h = appnp(feat, g) + assert h.shape[-1] == 5 + +def test_gin_conv(): + for aggregator_type in ['mean', 'max', 'sum']: + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + gin = nn.GINConv( + th.nn.Linear(5, 12), + aggregator_type + ) + feat = F.randn((100, 5)) + + if F.gpu_ctx(): + gin = gin.to(ctx) + feat = feat.to(ctx) + + h = gin(feat, g) + assert h.shape[-1] == 12 + +def test_agnn_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + agnn = nn.AGNNConv(1) + feat = F.randn((100, 5)) + + if F.gpu_ctx(): + agnn = agnn.to(ctx) + feat = feat.to(ctx) + + h = agnn(feat, g) + assert h.shape[-1] == 5 + +def test_gated_graph_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + ggconv = nn.GatedGraphConv(5, 10, 5, 3) + etypes = th.arange(g.number_of_edges()) % 3 + feat = F.randn((100, 5)) + + if F.gpu_ctx(): + ggconv = ggconv.to(ctx) + feat = feat.to(ctx) + etypes = etypes.to(ctx) + + h = ggconv(feat, etypes, g) + # current we only do shape check + assert h.shape[-1] == 10 + +def test_nn_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + edge_func = 
th.nn.Linear(4, 5 * 10) + nnconv = nn.NNConv(5, 10, edge_func, 'mean') + feat = F.randn((100, 5)) + efeat = F.randn((g.number_of_edges(), 4)) + + if F.gpu_ctx(): + nnconv = nnconv.to(ctx) + feat = feat.to(ctx) + efeat = efeat.to(ctx) + + h = nnconv(feat, efeat, g) + # currently we only do shape check + assert h.shape[-1] == 10 + +def test_gmm_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean') + feat = F.randn((100, 5)) + pseudo = F.randn((g.number_of_edges(), 3)) + + if F.gpu_ctx(): + gmmconv = gmmconv.to(ctx) + feat = feat.to(ctx) + pseudo = pseudo.to(ctx) + + h = gmmconv(feat, pseudo, g) + # currently we only do shape check + assert h.shape[-1] == 10 + +def test_dense_graph_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + adj = g.adjacency_matrix(ctx=ctx).to_dense() + conv = nn.GraphConv(5, 2, norm=False, bias=True) + dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True) + dense_conv.weight.data = conv.weight.data + dense_conv.bias.data = conv.bias.data + feat = F.randn((100, 5)) + if F.gpu_ctx(): + conv = conv.to(ctx) + dense_conv = dense_conv.to(ctx) + feat = feat.to(ctx) + + out_conv = conv(feat, g) + out_dense_conv = dense_conv(feat, adj) + assert F.allclose(out_conv, out_dense_conv) + +def test_dense_sage_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + adj = g.adjacency_matrix(ctx=ctx).to_dense() + sage = nn.SAGEConv(5, 2, 'gcn',) + dense_sage = nn.DenseSAGEConv(5, 2) + dense_sage.fc.weight.data = sage.fc_neigh.weight.data + dense_sage.fc.bias.data = sage.fc_neigh.bias.data + feat = F.randn((100, 5)) + if F.gpu_ctx(): + sage = sage.to(ctx) + dense_sage = dense_sage.to(ctx) + feat = feat.to(ctx) + + out_sage = sage(feat, g) + out_dense_sage = dense_sage(feat, adj) + assert F.allclose(out_sage, out_dense_sage) + +def test_dense_cheb_conv(): + for k in range(1, 4): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + adj = g.adjacency_matrix(ctx=ctx).to_dense() + cheb = nn.ChebConv(5, 2, k) + dense_cheb = nn.DenseChebConv(5, 2, k) + for i in range(len(cheb.fc)): + dense_cheb.W.data[i] = cheb.fc[i].weight.data.t() + if cheb.bias is not None: + dense_cheb.bias.data = cheb.bias.data + feat = F.randn((100, 5)) + if F.gpu_ctx(): + cheb = cheb.to(ctx) + dense_cheb = dense_cheb.to(ctx) + feat = feat.to(ctx) + + out_cheb = cheb(feat, g) + out_dense_cheb = dense_cheb(feat, adj) + assert F.allclose(out_cheb, out_dense_cheb) + if __name__ == '__main__': test_graph_conv() test_edge_softmax() @@ -362,3 +574,17 @@ def test_rgcn(): test_simple_pool() test_set_trans() test_rgcn() + test_tagconv() + test_gat_conv() + test_sage_conv() + test_sgc_conv() + test_appnp_conv() + test_gin_conv() + test_agnn_conv() + test_gated_graph_conv() + test_nn_conv() + test_gmm_conv() + test_dense_graph_conv() + test_dense_sage_conv() + test_dense_cheb_conv() +