gcn runnable
jermainewang committed Oct 4, 2018
1 parent 2be55fb commit fde4f58
Showing 9 changed files with 36 additions and 212 deletions.
38 changes: 2 additions & 36 deletions examples/pytorch/gcn/README.md
@@ -4,43 +4,9 @@ Graph Convolutional Networks (GCN)
Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)

The folder contains three different implementations using DGL.
The folder contains two different implementations using DGL.

Naive GCN (gcn.py)
-------
The model is defined at the finest granularity (i.e., on *one* edge and *one* node).

* The message function `gcn_msg` computes the message for one edge. It simply returns the `h` representation of the source node.
```python
def gcn_msg(src, edge):
# src['h'] is a tensor of shape (D,). D is the feature length.
return src['h']
```
* The reduce function `gcn_reduce` accumulates the incoming messages for one node. The `msgs` argument is a list of all the messages. In GCN, the incoming messages are summed up.
```python
def gcn_reduce(node, msgs):
# msgs is a list of in-coming messages.
return sum(msgs)
```
* The update function `NodeUpdateModule` computes the new node representation `h` by applying a non-linear transformation to the reduced messages.
```python
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation

def forward(self, node, accum):
# accum is a tensor of shape (D,).
h = self.linear(accum)
if self.activation:
h = self.activation(h)
return {'h' : h}
```

After defining the functions on each node/edge, the message passing is triggered by calling `update_all` on the DGLGraph object (in GCN module).
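
As a rough sketch of how these pieces fit together (the layer sizes below are made up for illustration; it assumes a `DGLGraph` object `g` whose nodes already carry an `h` feature, plus the `gcn_msg`, `gcn_reduce`, and `NodeUpdateModule` definitions above):
```python
import torch.nn.functional as F

# Hypothetical sizes; `g` is assumed to be a DGLGraph with g.nodes[n]['h'] set.
layer = NodeUpdateModule(in_feats=16, out_feats=16, activation=F.relu)

# One round of GCN message passing: every node gathers its neighbours' 'h',
# sums them, and applies the linear + activation update.
g.update_all(gcn_msg, gcn_reduce, layer)
```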

Batched GCN (gcn_batch.py)
Batched GCN (gcn.py)
-----------
Defining the model on only one node and edge makes it hard to fully utilize GPUs. As a result, we allow users to define the model on a *batch of* nodes and edges, as illustrated in the toy sketch below.
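
To make the batching concrete, here is a toy illustration of the reduce step in the batched API (the shapes are assumptions for this sketch only): messages arrive stacked per node and are summed along the message dimension, which is what `torch.sum(msgs, 1)` does in `gcn.py`.
```python
import torch

# Suppose 4 destination nodes each received 3 messages of length 5.
msgs = torch.randn(4, 3, 5)   # (num_nodes, num_messages, feature_dim)

# The batched reduce sums over the message dimension, yielding one
# accumulated vector per node.
accum = torch.sum(msgs, 1)    # shape: (4, 5)
print(accum.shape)            # torch.Size([4, 5])
```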

52 changes: 25 additions & 27 deletions examples/pytorch/gcn/gcn.py
@@ -2,21 +2,24 @@
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with batch processing
"""
import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gcn_msg(src, edge):
return src['h']
return src

def gcn_reduce(node, msgs):
return {'h' : sum(msgs)}
return torch.sum(msgs, 1)

class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
@@ -25,22 +28,22 @@ def __init__(self, in_feats, out_feats, activation=None):
self.activation = activation

def forward(self, node):
h = self.linear(node['h'])
h = self.linear(node)
if self.activation:
h = self.activation(h)
return {'h' : h}
return h

class GCN(nn.Module):
def __init__(self,
nx_graph,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__()
self.g = DGLGraph(nx_graph)
self.g = g
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
@@ -50,31 +53,24 @@ def __init__(self,
# output layer
self.layers.append(NodeApplyModule(n_hidden, n_classes))

def forward(self, features, train_nodes):
for n, feat in features.items():
self.g.nodes[n]['h'] = feat
def forward(self, features):
self.g.set_n_repr(features)
for layer in self.layers:
# apply dropout
if self.dropout:
self.g.nodes[n]['h'] = F.dropout(g.nodes[n]['h'], p=self.dropout)
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all(gcn_msg, gcn_reduce, layer)
return torch.cat([torch.unsqueeze(self.g.nodes[n]['h'], 0) for n in train_nodes])
return self.g.pop_n_repr()

def main(args):
# load and preprocess dataset
data = load_data(args)

# features of each sample
features = {}
labels = []
train_nodes = []
for n in data.graph.nodes():
features[n] = torch.FloatTensor(data.features[n, :])
if data.train_mask[n] == 1:
train_nodes.append(n)
labels.append(data.labels[n])
labels = torch.LongTensor(labels)
in_feats = data.features.shape[1]
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()

@@ -83,11 +79,13 @@ def main(args):
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = {k : v.cuda() for k, v in features.items()}
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()

# create GCN model
model = GCN(data.graph,
g = DGLGraph(data.graph)
model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
@@ -107,9 +105,9 @@ def main(args):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features, train_nodes)
logits = model(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, labels)
loss = F.nll_loss(logp[mask], labels[mask])

optimizer.zero_grad()
loss.backward()
Expand All @@ -130,7 +128,7 @@ def main(args):
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=10,
parser.add_argument("--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
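For a quick sanity check of the refactored script, something along these lines should run; the toy cycle graph and the feature/hidden sizes are assumptions, while the constructor and forward signatures follow the diff above.
```python
import networkx as nx
import torch
import torch.nn.functional as F
from dgl import DGLGraph
from gcn import GCN   # assuming the gcn.py shown above is importable

# Hypothetical smoke test: a 10-node cycle with 5-dimensional node features.
g = DGLGraph(nx.cycle_graph(10))
features = torch.randn(10, 5)

model = GCN(g,
            in_feats=5,
            n_hidden=8,
            n_classes=3,
            n_layers=1,
            activation=F.relu,
            dropout=0)
logits = model(features)   # expected shape: (10, 3)
```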
140 changes: 0 additions & 140 deletions examples/pytorch/gcn/gcn_batch.py

This file was deleted.

2 changes: 1 addition & 1 deletion examples/pytorch/gcn/gcn_spmv.py
@@ -55,7 +55,7 @@ def forward(self, features):
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all(fn.copy_src(), fn.sum(), layer, batchable=True)
self.g.update_all(fn.copy_src(), fn.sum(), layer)
return self.g.pop_n_repr()

def main(args):
8 changes: 7 additions & 1 deletion include/dgl/graph_op.h
@@ -10,10 +10,16 @@ class GraphOp {
public:
/*!
* \brief Return the line graph.
*
* If i~j and j~i are two edges in the original graph G, then
* (i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges in
* the line graph.
*
* \param graph The input graph.
* \param backtracking Whether the backtracking edges are included or not
* \return the line graph
*/
static Graph LineGraph(const Graph* graph);
static Graph LineGraph(const Graph* graph, bool backtracking);
/*!
* \brief Return a disjoint union of the input graphs.
*
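The backtracking notion is easiest to see on a tiny example. The snippet below is a pure-Python sketch (not the DGL C++ implementation) that enumerates line-graph edges from a directed edge list; with backtracking disabled, the (i,j)~(j,i) pairs described above are dropped.
```python
def line_graph_edges(edges, backtracking=True):
    """Edges of the line graph of a directed graph given as an edge list.

    Each directed edge (u, v) becomes a line-graph node; (u, v) connects
    to (v, w) for every edge (v, w). If backtracking is False, the pair
    (u, v) -> (v, u) is excluded.
    """
    result = []
    for (u, v) in edges:
        for (x, w) in edges:
            if x != v:
                continue
            if not backtracking and w == u:
                continue
            result.append(((u, v), (v, w)))
    return result

# A 2-cycle: 0 -> 1 and 1 -> 0.
edges = [(0, 1), (1, 0)]
print(line_graph_edges(edges, backtracking=True))   # the two backtracking edges
print(line_graph_edges(edges, backtracking=False))  # []
```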
2 changes: 0 additions & 2 deletions python/dgl/batch.py
@@ -143,12 +143,10 @@ def unbatch(graph):
node_frames = [FrameRef() for i in range(bsize)]
edge_frames = [FrameRef() for i in range(bsize)]
for attr, col in graph._node_frame.items():
# TODO: device context
col_splits = F.unpack(col, bn)
for i in range(bsize):
node_frames[i][attr] = col_splits[i]
for attr, col in graph._edge_frame.items():
# TODO: device context
col_splits = F.unpack(col, be)
for i in range(bsize):
edge_frames[i][attr] = col_splits[i]
6 changes: 1 addition & 5 deletions python/dgl/graph.py
@@ -31,15 +31,11 @@ class DGLGraph(object):
Node feature storage.
edge_frame : FrameRef
Edge feature storage.
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
def __init__(self,
graph_data=None,
node_frame=None,
edge_frame=None,
**attr):
# TODO: keyword attr
edge_frame=None):
# graph
self._graph = create_graph_index(graph_data)
# frame
File renamed without changes.
File renamed without changes.
