Skip to content

Commit

Permalink
[NN] Fix GCN module (dmlc#99)
Browse files Browse the repository at this point in the history
1. Update `examples/pytorch/gcn` and `python/dgl/nn/pytorch` based on the latest APIs
2. Add full support for dropout in `examples/pytorch/gcn` and `python/dgl/nn/pytorch`
3. Rename `GCN` class in `python/dgl/nn/pytorch` to be `GraphConvolutionLayer` class
4. Make node field an argument that can be configured by users in GraphConvolutionLayer

Note that adjacency normalization has not been supported yet in the examples.
  • Loading branch information
mufeili authored Oct 27, 2018
1 parent d0ea98b commit 2758c24
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 40 deletions.
20 changes: 11 additions & 9 deletions examples/pytorch/gcn/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
Graph Convolutional Networks (GCN)
============

Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)
- Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
- Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn). Note that the original code is
implemented with Tensorflow for the paper.

The folder contains two different implementations using DGL.

Expand All @@ -14,28 +15,29 @@ Defining the model on only one node and edge makes it hard to fully utilize GPUs
```python
def gcn_msg(src, edge):
# src is a tensor of shape (B, D). B is the number of edges being batched.
return src
return {'m' : src['h']}
```
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension fo the `msgs` argument:
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension for the `msgs` argument,
which for example can correspond to the neighbors of the nodes:
```python
def gcn_reduce(node, msgs):
# The msgs is a tensor of shape (B, deg, D). B is the number of nodes in the batch;
# deg is the number of messages; D is the message tensor dimension. DGL gaurantees
# that all the nodes in a batch have the same in-degrees (through "degree-bucketing").
# Reduce on the second dimension is equal to sum up all the in-coming messages.
return torch.sum(msgs, 1)
return {'h' : torch.sum(msgs['m'], 1)}
```
* The update module is similar. The first dimension of each tensor is the batch dimension. Since PyTorch operation is usually aware of the batch dimension, the code is the same as the naive GCN.

Triggering message passing is also similar. User needs to set `batchable=True` to indicate that the functions all support batching.
Triggering message passing is also similar.
```python
self.g.update_all(gcn_msg, gcn_reduce, layer, batchable=True)`
self.g.update_all(gcn_msg, gcn_reduce, layer)`
```

Batched GCN with spMV optimization (gcn_spmv.py)
-----------
Batched computation is much more efficient than naive vertex-centric approach, but is still not ideal. For example, the batched message function needs to look up source node data and save it on edges. Such kind of lookups is very common and incurs extra memory copy operations. In fact, the message and reduce phase of GCN model can be fused into one sparse-matrix-vector multiplication (spMV). Therefore, DGL provides many built-in message/reduce functions so we can figure out the chance of optimization. In gcn_spmv.py, user only needs to write update module and trigger the message passing as follows:
```python
self.g.update_all('from_src', 'sum', layer, batchable=True)
self.g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'), layer)
```
Here, `'from_src'` and `'sum'` are the builtin message and reduce function.
Here, `'fn.copy_src'` and `'fn.sum'` are the builtin message and reduce functions that perform the same operations as `gcn_msg` and `gcn_reduce` in gcn.py.
18 changes: 14 additions & 4 deletions examples/pytorch/gcn/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

Expand All @@ -24,13 +23,15 @@ def gcn_reduce(node, msgs):
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__()

self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation

def forward(self, node):
h = self.linear(node['h'])
if self.activation:
h = self.activation(h)

return {'h' : h}

class GCN(nn.Module):
Expand All @@ -44,27 +45,36 @@ def __init__(self,
dropout):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout

if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.

# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])

# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))

# output layer
self.layers.append(NodeApplyModule(n_hidden, n_classes))

def forward(self, features):
self.g.set_n_repr({'h' : features})

for layer in self.layers:
# apply dropout
if self.dropout:
g.apply_nodes(apply_node_func=
lambda node: F.dropout(node['h'], p=self.dropout))
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr('h')

def main(args):
# load and preprocess dataset
# Todo: adjacency normalization
data = load_data(args)

features = torch.FloatTensor(data.features)
Expand Down
20 changes: 15 additions & 5 deletions examples/pytorch/gcn/gcn_spmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,23 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__()

self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation

def forward(self, node):
h = self.linear(node['h'])
if self.activation:
h = self.activation(h)
return {'h' : h}

return {'h': h}

class GCN(nn.Module):
def __init__(self,
Expand All @@ -39,29 +40,38 @@ def __init__(self,
dropout):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout

if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.

# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])

# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))

# output layer
self.layers.append(NodeApplyModule(n_hidden, n_classes))

def forward(self, features):
self.g.set_n_repr({'h' : features})

for layer in self.layers:
# apply dropout
if self.dropout:
g.apply_nodes(apply_node_func=
lambda node: F.dropout(node['h'], p=self.dropout))
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'),
layer)
return self.g.pop_n_repr('h')

def main(args):
# load and preprocess dataset
# Todo: adjacency normalization
data = load_data(args)

features = torch.FloatTensor(data.features)
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .gcn import GCN
from .gcn import GraphConvolutionLayer
59 changes: 38 additions & 21 deletions python/dgl/nn/pytorch/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,61 @@
from ... import function as fn
from ...base import ALL, is_all


class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
def __init__(self, node_field, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()

self.node_field = node_field

self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
self.attribute = None

def forward(self, node):
h = self.linear(node['accum'])
h = self.linear(node[self.node_field])
if self.activation:
h = self.activation(h)
if self.attribute:
return {self.attribute: h}
else:
return h

class GCN(nn.Module):
return {self.node_field: h}

class GraphConvolutionLayer(nn.Module):
"""Single graph convolution layer as in https://arxiv.org/abs/1609.02907."""
def __init__(self,
node_field,
in_feats,
out_feats,
activation,
dropout=0):
super(GCN, self).__init__()
self.dropout = dropout
"""
node_filed: hashable keys for node features, e.g. 'h'
msg_field: hashable keys for message features, e.g. 'm'. In GCN, this is
just AH, where A is the adjacency matrix and H is current node features.
"""
super(GraphConvolutionLayer, self).__init__()

self.node_field = node_field

if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.

# input layer
self.update_func = NodeUpdateModule(in_feats, out_feats, activation)
self.update_func = NodeUpdateModule(node_field, in_feats, out_feats,
activation)

def forward(self, g, u=ALL, v=ALL):
if self.dropout:
g.apply_nodes(u, apply_node_func=
lambda node: {self.node_field: self.dropout(node[self.node_field])})

def forward(self, g, u=ALL, v=ALL, attribute=None):
if is_all(u) and is_all(v):
g.update_all(fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
g.update_all(fn.copy_src(src=self.node_field, out='m'),
fn.sum(msg='m', out=self.node_field),
self.update_func)
else:
g.send_and_recv(u, v,
fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
g.pop_n_repr('accum')
fn.copy_src(src=self.node_field, out='m'),
fn.sum(msg='m', out=self.node_field),
self.update_func)
return g

0 comments on commit 2758c24

Please sign in to comment.