diff --git a/examples/pytorch/gcn/README.md b/examples/pytorch/gcn/README.md index 78a93fabe272..b6d0a2f04069 100644 --- a/examples/pytorch/gcn/README.md +++ b/examples/pytorch/gcn/README.md @@ -1,8 +1,9 @@ Graph Convolutional Networks (GCN) ============ -Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907) -Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn) +- Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907) +- Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn). Note that the original code is +implemented with Tensorflow for the paper. The folder contains two different implementations using DGL. @@ -14,28 +15,29 @@ Defining the model on only one node and edge makes it hard to fully utilize GPUs ```python def gcn_msg(src, edge): # src is a tensor of shape (B, D). B is the number of edges being batched. - return src + return {'m' : src['h']} ``` -* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension fo the `msgs` argument: +* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension for the `msgs` argument, +which for example can correspond to the neighbors of the nodes: ```python def gcn_reduce(node, msgs): # The msgs is a tensor of shape (B, deg, D). B is the number of nodes in the batch; # deg is the number of messages; D is the message tensor dimension. DGL gaurantees # that all the nodes in a batch have the same in-degrees (through "degree-bucketing"). # Reduce on the second dimension is equal to sum up all the in-coming messages. - return torch.sum(msgs, 1) + return {'h' : torch.sum(msgs['m'], 1)} ``` * The update module is similar. The first dimension of each tensor is the batch dimension. Since PyTorch operation is usually aware of the batch dimension, the code is the same as the naive GCN. -Triggering message passing is also similar. User needs to set `batchable=True` to indicate that the functions all support batching. +Triggering message passing is also similar. ```python -self.g.update_all(gcn_msg, gcn_reduce, layer, batchable=True)` +self.g.update_all(gcn_msg, gcn_reduce, layer)` ``` Batched GCN with spMV optimization (gcn_spmv.py) ----------- Batched computation is much more efficient than naive vertex-centric approach, but is still not ideal. For example, the batched message function needs to look up source node data and save it on edges. Such kind of lookups is very common and incurs extra memory copy operations. In fact, the message and reduce phase of GCN model can be fused into one sparse-matrix-vector multiplication (spMV). Therefore, DGL provides many built-in message/reduce functions so we can figure out the chance of optimization. In gcn_spmv.py, user only needs to write update module and trigger the message passing as follows: ```python -self.g.update_all('from_src', 'sum', layer, batchable=True) +self.g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'), layer) ``` -Here, `'from_src'` and `'sum'` are the builtin message and reduce function. +Here, `'fn.copy_src'` and `'fn.sum'` are the builtin message and reduce functions that perform the same operations as `gcn_msg` and `gcn_reduce` in gcn.py. diff --git a/examples/pytorch/gcn/gcn.py b/examples/pytorch/gcn/gcn.py index ac24b73c79e5..4c26c88c738c 100644 --- a/examples/pytorch/gcn/gcn.py +++ b/examples/pytorch/gcn/gcn.py @@ -11,7 +11,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import dgl from dgl import DGLGraph from dgl.data import register_data_args, load_data @@ -24,6 +23,7 @@ def gcn_reduce(node, msgs): class NodeApplyModule(nn.Module): def __init__(self, in_feats, out_feats, activation=None): super(NodeApplyModule, self).__init__() + self.linear = nn.Linear(in_feats, out_feats) self.activation = activation @@ -31,6 +31,7 @@ def forward(self, node): h = self.linear(node['h']) if self.activation: h = self.activation(h) + return {'h' : h} class GCN(nn.Module): @@ -44,27 +45,36 @@ def __init__(self, dropout): super(GCN, self).__init__() self.g = g - self.dropout = dropout + + if dropout: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = 0. + # input layer self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)]) + # hidden layers for i in range(n_layers - 1): self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation)) + # output layer self.layers.append(NodeApplyModule(n_hidden, n_classes)) def forward(self, features): self.g.set_n_repr({'h' : features}) + for layer in self.layers: # apply dropout if self.dropout: - g.apply_nodes(apply_node_func= - lambda node: F.dropout(node['h'], p=self.dropout)) + self.g.apply_nodes(apply_node_func= + lambda node: {'h': self.dropout(node['h'])}) self.g.update_all(gcn_msg, gcn_reduce, layer) return self.g.pop_n_repr('h') def main(args): # load and preprocess dataset + # Todo: adjacency normalization data = load_data(args) features = torch.FloatTensor(data.features) diff --git a/examples/pytorch/gcn/gcn_spmv.py b/examples/pytorch/gcn/gcn_spmv.py index 4feab69cb5d1..295a345f7aa9 100644 --- a/examples/pytorch/gcn/gcn_spmv.py +++ b/examples/pytorch/gcn/gcn_spmv.py @@ -11,7 +11,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import dgl import dgl.function as fn from dgl import DGLGraph from dgl.data import register_data_args, load_data @@ -19,6 +18,7 @@ class NodeApplyModule(nn.Module): def __init__(self, in_feats, out_feats, activation=None): super(NodeApplyModule, self).__init__() + self.linear = nn.Linear(in_feats, out_feats) self.activation = activation @@ -26,7 +26,8 @@ def forward(self, node): h = self.linear(node['h']) if self.activation: h = self.activation(h) - return {'h' : h} + + return {'h': h} class GCN(nn.Module): def __init__(self, @@ -39,22 +40,30 @@ def __init__(self, dropout): super(GCN, self).__init__() self.g = g - self.dropout = dropout + + if dropout: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = 0. + # input layer self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)]) + # hidden layers for i in range(n_layers - 1): self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation)) + # output layer self.layers.append(NodeApplyModule(n_hidden, n_classes)) def forward(self, features): self.g.set_n_repr({'h' : features}) + for layer in self.layers: # apply dropout if self.dropout: - g.apply_nodes(apply_node_func= - lambda node: F.dropout(node['h'], p=self.dropout)) + self.g.apply_nodes(apply_node_func= + lambda node: {'h': self.dropout(node['h'])}) self.g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'), layer) @@ -62,6 +71,7 @@ def forward(self, features): def main(args): # load and preprocess dataset + # Todo: adjacency normalization data = load_data(args) features = torch.FloatTensor(data.features) diff --git a/python/dgl/nn/pytorch/__init__.py b/python/dgl/nn/pytorch/__init__.py index e490b75d0dc6..ef157a9ea8ae 100644 --- a/python/dgl/nn/pytorch/__init__.py +++ b/python/dgl/nn/pytorch/__init__.py @@ -1 +1 @@ -from .gcn import GCN +from .gcn import GraphConvolutionLayer diff --git a/python/dgl/nn/pytorch/gcn.py b/python/dgl/nn/pytorch/gcn.py index e2733dfc5966..cddb68e49209 100644 --- a/python/dgl/nn/pytorch/gcn.py +++ b/python/dgl/nn/pytorch/gcn.py @@ -10,44 +10,61 @@ from ... import function as fn from ...base import ALL, is_all + class NodeUpdateModule(nn.Module): - def __init__(self, in_feats, out_feats, activation=None): + def __init__(self, node_field, in_feats, out_feats, activation=None): super(NodeUpdateModule, self).__init__() + + self.node_field = node_field + self.linear = nn.Linear(in_feats, out_feats) self.activation = activation - self.attribute = None def forward(self, node): - h = self.linear(node['accum']) + h = self.linear(node[self.node_field]) if self.activation: h = self.activation(h) - if self.attribute: - return {self.attribute: h} - else: - return h -class GCN(nn.Module): + return {self.node_field: h} + +class GraphConvolutionLayer(nn.Module): + """Single graph convolution layer as in https://arxiv.org/abs/1609.02907.""" def __init__(self, + node_field, in_feats, out_feats, activation, dropout=0): - super(GCN, self).__init__() - self.dropout = dropout + """ + node_filed: hashable keys for node features, e.g. 'h' + msg_field: hashable keys for message features, e.g. 'm'. In GCN, this is + just AH, where A is the adjacency matrix and H is current node features. + """ + super(GraphConvolutionLayer, self).__init__() + + self.node_field = node_field + + if dropout: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = 0. + # input layer - self.update_func = NodeUpdateModule(in_feats, out_feats, activation) + self.update_func = NodeUpdateModule(node_field, in_feats, out_feats, + activation) + + def forward(self, g, u=ALL, v=ALL): + if self.dropout: + g.apply_nodes(u, apply_node_func= + lambda node: {self.node_field: self.dropout(node[self.node_field])}) - def forward(self, g, u=ALL, v=ALL, attribute=None): if is_all(u) and is_all(v): - g.update_all(fn.copy_src(src=attribute), - fn.sum(out='accum'), - self.update_func, - batchable=True) + g.update_all(fn.copy_src(src=self.node_field, out='m'), + fn.sum(msg='m', out=self.node_field), + self.update_func) else: g.send_and_recv(u, v, - fn.copy_src(src=attribute), - fn.sum(out='accum'), - self.update_func, - batchable=True) - g.pop_n_repr('accum') + fn.copy_src(src=self.node_field, out='m'), + fn.sum(msg='m', out=self.node_field), + self.update_func) return g