[API][Doc] API change & basic tutorials (dmlc#113)
* Add SH tutorials

* set up sphinx-gallery; work on graph tutorial

* draft dglgraph tutorial

* update readme to include document url

* rm obsolete file

* Draft the message passing tutorial

* Capsule code (dmlc#102)

* add capsule example

* clean code

* better naming

* better naming

* [GCN]tutorial scaffold

* fix capsule example code

* remove previous capsule example code

* graph structure edit

* modified:   2_graph.py

* update doc of capsule

* update capsule docs

* update capsule docs

* add message passing primer

* GCN-GAT tutorial Section 1 and 2

* comment for API improvement

* section 3

* Tutorial API change (dmlc#115)

* change the API as discussed; add toy example

* enable the new set/get syntax

* fixed pytorch utest

* fixed gcn example

* fixed gat example

* fixed mx utests

* fix mx utest

* delete apply edges; add utest for update_edges

* small change on toy example

* fix utest

* fix out/in degrees bug

* update pagerank example and add it to CI

* add delitem for dataview

* make edges() return a form that is compatible with send/update_edges, etc.

* fix index bug when the given data is a one-int tensor

* fix doc
jermainewang authored Nov 2, 2018
1 parent 2ecd2b2 commit 68ec624
Showing 29 changed files with 1,829 additions and 911 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -17,7 +17,7 @@ def build_dgl() {
     }
     dir ('build') {
         sh 'cmake ..'
-        sh 'make -j$(nproc)'
+        sh 'make -j4'
     }
 }

3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -46,6 +46,7 @@
     'sphinx.ext.napoleon',
     'sphinx.ext.viewcode',
     'sphinx.ext.intersphinx',
+    'sphinx.ext.graphviz',
     'sphinx_gallery.gen_gallery',
 ]

@@ -193,4 +194,4 @@
     'gallery_dirs' : gallery_dirs,
     'within_subsection_order' : FileNameSortKey,
     'filename_pattern' : '.py',
-}
\ No newline at end of file
+}
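For orientation, the two hunks above touch the Sphinx extension list and the tail of the `sphinx_gallery_conf` dictionary. A minimal sketch of how the pieces fit together in a `conf.py`, assuming values for the directory keys (everything except the keys visible in this diff is an assumption):

from sphinx_gallery.sorting import FileNameSortKey

extensions = [
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx.ext.intersphinx',
    'sphinx.ext.graphviz',
    'sphinx_gallery.gen_gallery',
]

sphinx_gallery_conf = {
    'examples_dirs': '../../tutorials',   # assumed input directory
    'gallery_dirs': 'tutorials',          # assumed output directory
    'within_subsection_order': FileNameSortKey,
    'filename_pattern': '.py',
}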
35 changes: 0 additions & 35 deletions examples/pagerank.py

This file was deleted.

24 changes: 12 additions & 12 deletions examples/pytorch/gat/gat.py
@@ -16,18 +16,18 @@
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data

-def gat_message(src, edge):
-    return {'ft' : src['ft'], 'a2' : src['a2']}
+def gat_message(edges):
+    return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}

 class GATReduce(nn.Module):
     def __init__(self, attn_drop):
         super(GATReduce, self).__init__()
         self.attn_drop = attn_drop

-    def forward(self, node, msgs):
-        a1 = torch.unsqueeze(node['a1'], 1) # shape (B, 1, 1)
-        a2 = msgs['a2'] # shape (B, deg, 1)
-        ft = msgs['ft'] # shape (B, deg, D)
+    def forward(self, nodes):
+        a1 = torch.unsqueeze(nodes.data['a1'], 1) # shape (B, 1, 1)
+        a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
+        ft = nodes.mailbox['ft'] # shape (B, deg, D)
         # attention
         a = a1 + a2 # shape (B, deg, 1)
         e = F.softmax(F.leaky_relu(a), dim=1)
@@ -46,13 +46,13 @@ def __init__(self, headid, indim, hiddendim, activation, residual):
         if indim != hiddendim:
             self.residual_fc = nn.Linear(indim, hiddendim)

-    def forward(self, node):
-        ret = node['accum']
+    def forward(self, nodes):
+        ret = nodes.data['accum']
         if self.residual:
             if self.residual_fc is not None:
-                ret = self.residual_fc(node['h']) + ret
+                ret = self.residual_fc(nodes.data['h']) + ret
             else:
-                ret = node['h'] + ret
+                ret = nodes.data['h'] + ret
         return {'head%d' % self.headid : self.activation(ret)}

 class GATPrepare(nn.Module):
@@ -120,15 +120,15 @@ def forward(self, features):
             for hid in range(self.num_heads):
                 i = l * self.num_heads + hid
                 # prepare
-                self.g.set_n_repr(self.prp[i](last))
+                self.g.ndata.update(self.prp[i](last))
                 # message passing
                 self.g.update_all(gat_message, self.red[i], self.fnl[i])
             # merge all the heads
             last = torch.cat(
                 [self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
                 dim=1)
         # output projection
-        self.g.set_n_repr(self.prp[-1](last))
+        self.g.ndata.update(self.prp[-1](last))
         self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
         return self.g.pop_n_repr('head0')

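These hunks show the core of the API change: a message function now receives a single EdgeBatch (with .src, .dst, and .data feature dictionaries) instead of separate (src, edge) arguments, and a reduce function receives a single NodeBatch (with .data and .mailbox) instead of (node, msgs). A minimal before/after sketch of the convention, as standalone toy code rather than code from this diff:

def toy_message(edges):
    # edges.src['h'] : source-node features, batched over the edges
    return {'m': edges.src['h']}

def toy_reduce(nodes):
    # nodes.mailbox['m'] stacks incoming messages per node: (B, deg, D)
    return {'h': nodes.mailbox['m'].sum(1)}

# The old equivalents, removed by this commit, looked like:
#   def message(src, edge): return {'m': src['h']}
#   def reduce(node, msgs): return {'h': msgs['m'].sum(1)}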
18 changes: 8 additions & 10 deletions examples/pytorch/gcn/gcn.py
@@ -14,24 +14,22 @@
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data

-def gcn_msg(src, edge):
-    return {'m' : src['h']}
+def gcn_msg(edges):
+    return {'m' : edges.src['h']}

-def gcn_reduce(node, msgs):
-    return {'h' : torch.sum(msgs['m'], 1)}
+def gcn_reduce(nodes):
+    return {'h' : torch.sum(nodes.mailbox['m'], 1)}

 class NodeApplyModule(nn.Module):
     def __init__(self, in_feats, out_feats, activation=None):
         super(NodeApplyModule, self).__init__()
-
         self.linear = nn.Linear(in_feats, out_feats)
         self.activation = activation

-    def forward(self, node):
-        h = self.linear(node['h'])
+    def forward(self, nodes):
+        h = self.linear(nodes.data['h'])
         if self.activation:
             h = self.activation(h)
-
         return {'h' : h}

 class GCN(nn.Module):
@@ -62,13 +60,13 @@ def __init__(self,
         self.layers.append(NodeApplyModule(n_hidden, n_classes))

     def forward(self, features):
-        self.g.set_n_repr({'h' : features})
+        self.g.ndata['h'] = features

         for layer in self.layers:
             # apply dropout
             if self.dropout:
                 self.g.apply_nodes(apply_node_func=
-                        lambda node: {'h': self.dropout(node['h'])})
+                        lambda nodes: {'h': self.dropout(nodes.data['h'])})
             self.g.update_all(gcn_msg, gcn_reduce, layer)
         return self.g.pop_n_repr('h')

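Besides the UDF signatures, gcn.py exercises the new set/get syntax from the commit message: g.ndata is a dict-like view over node features, so g.set_n_repr({'h': x}) becomes g.ndata['h'] = x. A toy sketch of the correspondence (assumed usage, not code from this diff):

import torch
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
x = torch.randn(3, 4)

g.ndata['h'] = x            # was: g.set_n_repr({'h': x})
h = g.ndata['h']            # read a single field back
g.ndata.update({'w': x})    # set several fields at once
del g.ndata['w']            # the "delitem for dataview" from the commit message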
8 changes: 4 additions & 4 deletions examples/pytorch/gcn/gcn_spmv.py
@@ -22,8 +22,8 @@ def __init__(self, in_feats, out_feats, activation=None):
         self.linear = nn.Linear(in_feats, out_feats)
         self.activation = activation

-    def forward(self, node):
-        h = self.linear(node['h'])
+    def forward(self, nodes):
+        h = self.linear(nodes.data['h'])
         if self.activation:
             h = self.activation(h)

@@ -57,13 +57,13 @@ def __init__(self,
         self.layers.append(NodeApplyModule(n_hidden, n_classes))

     def forward(self, features):
-        self.g.set_n_repr({'h' : features})
+        self.g.ndata['h'] = features

         for layer in self.layers:
             # apply dropout
             if self.dropout:
                 self.g.apply_nodes(apply_node_func=
-                        lambda node: {'h': self.dropout(node['h'])})
+                        lambda nodes: {'h': self.dropout(nodes.data['h'])})
             self.g.update_all(fn.copy_src(src='h', out='m'),
                               fn.sum(msg='m', out='h'),
                               layer)
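gcn_spmv.py is the same model with the Python UDFs replaced by builtins, which is what makes the SpMV fast path possible (see is_spmv_supported in message.py and reducer.py below). A sketch of the equivalence for a DGLGraph g carrying a node field 'h' (toy code, not from this diff):

import dgl.function as fn

# Builtin pair: eligible for fused sparse matrix-vector execution.
g.update_all(fn.copy_src(src='h', out='m'),
             fn.sum(msg='m', out='h'))

# Semantically equivalent UDF pair: always runs the generic path.
g.update_all(lambda edges: {'m': edges.src['h']},
             lambda nodes: {'h': nodes.mailbox['m'].sum(1)})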
24 changes: 24 additions & 0 deletions examples/pytorch/pagerank.py
@@ -0,0 +1,24 @@
+import networkx as nx
+import torch
+import dgl
+import dgl.function as fn
+
+N = 100
+g = nx.erdos_renyi_graph(N, 0.05)
+g = dgl.DGLGraph(g)
+
+DAMP = 0.85
+K = 10
+
+def compute_pagerank(g):
+    g.ndata['pv'] = torch.ones(N) / N
+    degrees = g.out_degrees(g.nodes()).type(torch.float32)
+    for k in range(K):
+        g.ndata['pv'] = g.ndata['pv'] / degrees
+        g.update_all(message_func=fn.copy_src(src='pv', out='m'),
+                     reduce_func=fn.sum(msg='m', out='pv'))
+        g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
+    return g.ndata['pv']
+
+pv = compute_pagerank(g)
+print(pv)
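Each iteration implements the damped PageRank update: every node spreads pv / out_degree along its out-edges, incoming mass is summed, and damping is applied. The same step written with plain UDFs instead of builtins, to connect it to the new API (hypothetical helper names; a 'deg' node field is introduced for illustration):

def pagerank_message(edges):
    return {'m': edges.src['pv'] / edges.src['deg']}

def pagerank_reduce(nodes):
    return {'pv': torch.sum(nodes.mailbox['m'], 1)}

def pagerank_step(g):
    g.ndata['deg'] = g.out_degrees(g.nodes()).type(torch.float32)
    g.update_all(pagerank_message, pagerank_reduce)
    g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']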
1 change: 1 addition & 0 deletions python/dgl/__init__.py
@@ -11,3 +11,4 @@
 from .batched_graph import *
 from .graph import DGLGraph
 from .subgraph import DGLSubGraph
+from .udf import NodeBatch, EdgeBatch
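Exporting NodeBatch and EdgeBatch makes the UDF argument types part of the public API. Their interface, as inferred from how this commit's code uses them (a sketch, not the actual udf.py definitions):

class EdgeBatch:      # argument to message functions
    src: dict         # source-node features, e.g. edges.src['h']
    dst: dict         # destination-node features
    data: dict        # the edges' own features

class NodeBatch:      # argument to reduce/apply functions
    data: dict        # node features, e.g. nodes.data['h']
    mailbox: dict     # incoming messages, stacked to (num_nodes, deg, ...)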
34 changes: 12 additions & 22 deletions python/dgl/function/message.py
@@ -10,7 +10,7 @@
 class MessageFunction(object):
     """Base builtin message function class."""

-    def __call__(self, src, edge):
+    def __call__(self, edges):
         """Regular computation of this builtin.

         This will be used when optimization is not available.
@@ -38,15 +38,11 @@ def is_spmv_supported(self, g):
                 return False
         return True

-    def __call__(self, src, edge):
-        ret = None
+    def __call__(self, edges):
+        ret = dict()
         for fn in self.fn_list:
-            msg = fn(src, edge)
-            if ret is None:
-                ret = msg
-            else:
-                # ret and msg must be dict
-                ret.update(msg)
+            msg = fn(edges)
+            ret.update(msg)
         return ret

     def name(self):
@@ -83,8 +79,9 @@ def is_spmv_supported(self, g):
         return _is_spmv_supported_node_feat(g, self.src_field) \
             and _is_spmv_supported_edge_feat(g, self.edge_field)

-    def __call__(self, src, edge):
-        ret = self.mul_op(src[self.src_field], edge[self.edge_field])
+    def __call__(self, edges):
+        ret = self.mul_op(edges.src[self.src_field],
+                          edges.data[self.edge_field])
         return {self.out_field : ret}

     def name(self):
@@ -98,8 +95,8 @@ def __init__(self, src_field, out_field):
     def is_spmv_supported(self, g):
         return _is_spmv_supported_node_feat(g, self.src_field)

-    def __call__(self, src, edge):
-        return {self.out_field : src[self.src_field]}
+    def __call__(self, edges):
+        return {self.out_field : edges.src[self.src_field]}

     def name(self):
         return "copy_src"
@@ -114,15 +111,8 @@ def is_spmv_supported(self, g):
         return False
         # return _is_spmv_supported_edge_feat(g, self.edge_field)

-    def __call__(self, src, edge):
-        if self.edge_field is not None:
-            ret = edge[self.edge_field]
-        else:
-            ret = edge
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+    def __call__(self, edges):
+        return {self.out_field : edges.data[self.edge_field]}

     def name(self):
         return "copy_edge"
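The BundledFunction simplification works because every builtin now always returns a dict keyed by its out field, so results can be merged unconditionally. In spirit, the bundled call above does the following (a sketch of the diff's logic; the commented usage line is an assumption about how bundling is triggered):

def bundled_call(fn_list, edges):
    ret = dict()
    for fn in fn_list:
        ret.update(fn(edges))   # each builtin contributes its own out-field
    return ret

# Assumed usage: several messages computed in one update_all pass, e.g.
# g.update_all([fn.copy_src(src='h', out='m1'),
#               fn.src_mul_edge(src='h', edge='w', out='m2')], reduce_fn)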
18 changes: 7 additions & 11 deletions python/dgl/function/reducer.py
@@ -8,7 +8,7 @@
 class ReduceFunction(object):
     """Base builtin reduce function class."""

-    def __call__(self, node, msgs):
+    def __call__(self, nodes):
         """Regular computation of this builtin.

         This will be used when optimization is not available.
@@ -35,15 +35,11 @@ def is_spmv_supported(self):
                 return False
         return True

-    def __call__(self, node, msgs):
-        ret = None
+    def __call__(self, nodes):
+        ret = dict()
         for fn in self.fn_list:
-            rpr = fn(node, msgs)
-            if ret is None:
-                ret = rpr
-            else:
-                # ret and rpr must be dict
-                ret.update(rpr)
+            rpr = fn(nodes)
+            ret.update(rpr)
         return ret

     def name(self):
@@ -60,8 +56,8 @@ def is_spmv_supported(self):
         # NOTE: only sum is supported right now.
         return self.name == "sum"

-    def __call__(self, node, msgs):
-        return {self.out_field : self.op(msgs[self.msg_field], 1)}
+    def __call__(self, nodes):
+        return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}

     def name(self):
         return self.name
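The sum reducer now reduces nodes.mailbox[...] along dimension 1, the incoming-edge axis of the (num_nodes, deg, feature) mailbox layout. A toy illustration of the shape convention (assumed shapes, not from this diff):

import torch

# Messages for 4 nodes, each of in-degree 3, with 2-dim features.
mailbox_m = torch.ones(4, 3, 2)

reduced = torch.sum(mailbox_m, 1)   # shape (4, 2); every entry equals 3.0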
(Diffs for the remaining 19 changed files are not shown here.)
