[NN] Add commonly used GNN models from examples to dgl.nn modules. (dmlc#748)

* gat
* upd
* upd sage
* upd
* upd
* upd
* upd
* upd
* add gmmconv
* upd ggnn
* upd
* upd
* upd
* upd
* add citation examples
* add README
* fix cheb
* improve doc
* formula
* upd
* trigger
* lint
* lint
* upd
* add test for transform
* add test
* check
* upd
* improve doc
* shape check
* upd
* densechebconv, currently not correct (?)
* fix cheb
* fix
* upd
* upd sgc-reddit
* upd
* trigger
yzh119 authored Aug 27, 2019
1 parent 8079d98 commit 650f6ee
Showing 29 changed files with 2,450 additions and 451 deletions.
29 changes: 29 additions & 0 deletions docs/source/api/python/function.rst
@@ -14,6 +14,32 @@ Message functions
copy_src
copy_edge
src_mul_edge
copy_u
copy_e
u_add_v
u_sub_v
u_mul_v
u_div_v
u_add_e
u_sub_e
u_mul_e
u_div_e
v_add_u
v_sub_u
v_mul_u
v_div_u
v_add_e
v_sub_e
v_mul_e
v_div_e
e_add_u
e_sub_u
e_mul_u
e_div_u
e_add_v
e_sub_v
e_mul_v
e_div_v

Reduce functions
----------------
@@ -23,3 +49,6 @@ Reduce functions

sum
max
min
prod
mean
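
A note for readers skimming this diff: the new builtins follow a u/v/e naming scheme (u = source node, v = destination node, e = edge). Below is a minimal sketch of how they combine with `update_all`, assuming the new builtins take the same positional (field, field, out) arguments as the existing `src_mul_edge` used later in this diff; the toy graph and field names are illustrative only.

    import torch
    import dgl
    import dgl.function as fn

    # toy 3-node cycle with 4-dim node features and scalar edge weights
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 0])
    g.ndata['h'] = torch.randn(3, 4)
    g.edata['w'] = torch.rand(g.number_of_edges(), 1)

    # message: source feature times edge weight; reduce: mean over incoming messages
    g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.mean('m', 'h_new'))
    print(g.ndata['h_new'].shape)  # torch.Size([3, 4])
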
56 changes: 56 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
@@ -16,6 +16,62 @@ dgl.nn.pytorch.conv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.TAGConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.GATConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.SAGEConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.SGConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.APPNPConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.GINConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.GatedGraphConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.GMMConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.ChebConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.AGNNConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.NNConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.DenseGraphConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.DenseSAGEConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.DenseChebConv
:members: forward
:show-inheritance:

dgl.nn.pytorch.glob
-------------------
.. automodule:: dgl.nn.pytorch.glob
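
The conv modules documented above all follow the same usage pattern in this release: construct the module once, then call it with (features, graph). Here is a hedged sketch using SAGEConv; the constructor arguments (in_feats, out_feats, aggregator_type) are an assumption, while the (feat, graph) call order mirrors the example refactors later in this commit.

    import torch
    import dgl
    from dgl.nn.pytorch.conv import SAGEConv

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
    feat = torch.randn(4, 10)

    conv = SAGEConv(10, 16, 'mean')  # assumed signature: in_feats, out_feats, aggregator type
    out = conv(feat, g)              # graph passed at call time, not stored in the module
    print(out.shape)                 # expected: torch.Size([4, 16])
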
3 changes: 3 additions & 0 deletions docs/source/api/python/transform.rst
@@ -12,3 +12,6 @@ Transform -- Graph Transformation
reverse
to_simple_graph
to_bidirected
khop_adj
khop_graph
laplacian_lambda_max
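
The three new transform utilities pair with the new conv modules (laplacian_lambda_max, for instance, computes the quantity ChebConv typically needs for rescaling). A short sketch, assuming they are exposed at the top-level dgl namespace like reverse and to_bidirected above:

    import dgl

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

    A2 = dgl.khop_adj(g, 2)              # dense 2-hop adjacency (power of the adjacency matrix)
    g2 = dgl.khop_graph(g, 2)            # graph whose edges connect 2-hop reachable pairs
    lmax = dgl.laplacian_lambda_max(g)   # largest eigenvalue(s) of the normalized graph Laplacian
    print(A2.shape, g2.number_of_edges(), lmax)
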
44 changes: 4 additions & 40 deletions examples/pytorch/appnp/appnp.py
@@ -8,44 +8,7 @@
import torch
import torch.nn as nn
import dgl.function as fn


class GraphPropagation(nn.Module):
    def __init__(self,
                 g,
                 edge_drop,
                 alpha,
                 k):
        super(GraphPropagation, self).__init__()
        self.g = g
        self.alpha = alpha
        self.k = k
        if edge_drop:
            self.edge_drop = nn.Dropout(edge_drop)
        else:
            self.edge_drop = 0.

    def forward(self, h):
        self.cached_h = h
        for _ in range(self.k):
            # normalization by square root of src degree
            h = h * self.g.ndata['norm']
            self.g.ndata['h'] = h
            if self.edge_drop:
                # performing edge dropout
                ed = self.edge_drop(torch.ones((self.g.number_of_edges(), 1), device=h.device))
                self.g.edata['d'] = ed
                self.g.update_all(fn.src_mul_edge(src='h', edge='d', out='m'),
                                  fn.sum(msg='m', out='h'))
            else:
                self.g.update_all(fn.copy_src(src='h', out='m'),
                                  fn.sum(msg='m', out='h'))
            h = self.g.ndata.pop('h')
            # normalization by square root of dst degree
            h = h * self.g.ndata['norm']
            # update h using teleport probability alpha
            h = h * (1 - self.alpha) + self.cached_h * self.alpha
        return h
from dgl.nn.pytorch.conv import APPNPConv


class APPNP(nn.Module):
@@ -60,6 +23,7 @@ def __init__(self,
                 alpha,
                 k):
        super(APPNP, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(nn.Linear(in_feats, hiddens[0]))
@@ -73,7 +37,7 @@ def __init__(self,
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x: x
        self.propagate = GraphPropagation(g, edge_drop, alpha, k)
        self.propagate = APPNPConv(k, alpha, edge_drop)
        self.reset_parameters()

    def reset_parameters(self):
@@ -89,5 +53,5 @@ def forward(self, features):
            h = self.activation(layer(h))
        h = self.layers[-1](self.feat_drop(h))
        # propagation step
        h = self.propagate(h)
        h = self.propagate(h, self.g)
        return h
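
The refactor above replaces the example's hand-written GraphPropagation module with the new APPNPConv, constructed with the same three hyperparameters and called with (feat, graph). A small sketch of the replacement in isolation; the toy graph and feature sizes are illustrative only.

    import torch
    import dgl
    from dgl.nn.pytorch.conv import APPNPConv

    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])
    h = torch.randn(5, 16)           # e.g. the output of the MLP layers in APPNP

    prop = APPNPConv(10, 0.1, 0.0)   # k=10 propagation steps, teleport alpha=0.1, no edge dropout
    h = prop(h, g)                   # (feat, graph) order, matching the forward() change above
    print(h.shape)                   # torch.Size([5, 16]); propagation preserves the feature size
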
96 changes: 13 additions & 83 deletions examples/pytorch/gat/gat.py
@@ -10,78 +10,8 @@
import torch
import torch.nn as nn
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax
from dgl.nn.pytorch import edge_softmax, GATConv

class GraphAttention(nn.Module):
    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 num_heads,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual=False):
        super(GraphAttention, self).__init__()
        self.g = g
        self.num_heads = num_heads
        self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
        if feat_drop:
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x : x
        if attn_drop:
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            self.attn_drop = lambda x : x
        self.attn_l = nn.Parameter(torch.Tensor(size=(1, num_heads, out_dim)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(1, num_heads, out_dim)))
        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(alpha)
        self.softmax = edge_softmax
        self.residual = residual
        if residual:
            if in_dim != out_dim:
                self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
                nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414)
            else:
                self.res_fc = None

    def forward(self, inputs):
        # prepare
        h = self.feat_drop(inputs)  # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
        a1 = (ft * self.attn_l).sum(dim=-1).unsqueeze(-1)  # N x H x 1
        a2 = (ft * self.attn_r).sum(dim=-1).unsqueeze(-1)  # N x H x 1
        self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute softmax
        self.edge_softmax()
        # 3. compute the aggregated node features scaled by the dropped,
        # unnormalized attention values.
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
        ret = self.g.ndata['ft']
        # 4. residual
        if self.residual:
            if self.res_fc is not None:
                resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
            else:
                resval = torch.unsqueeze(h, 1)  # Nx1xD'
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        # an edge UDF to compute unnormalized attention values from src and dst
        a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
        return {'a' : a}

    def edge_softmax(self):
        attention = self.softmax(self.g, self.g.edata.pop('a'))
        # Dropout attention scores and save them
        self.g.edata['a_drop'] = self.attn_drop(attention)

class GAT(nn.Module):
    def __init__(self,
@@ -94,32 +24,32 @@ def __init__(self,
                 activation,
                 feat_drop,
                 attn_drop,
                 alpha,
                 negative_slope,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GraphAttention(
            g, in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False))
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GraphAttention(
                g, num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, alpha, residual))
            self.gat_layers.append(GATConv(
                num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, negative_slope, residual, self.activation))
        # output projection
        self.gat_layers.append(GraphAttention(
            g, num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, alpha, residual))
        self.gat_layers.append(GATConv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, negative_slope, residual, None))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            h = self.gat_layers[l](h).flatten(1)
            h = self.activation(h)
            h = self.gat_layers[l](h, self.g).flatten(1)
        # output projection
        logits = self.gat_layers[-1](h).mean(1)
        logits = self.gat_layers[-1](h, self.g).mean(1)
        return logits
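
GATConv replaces the example's GraphAttention class; per the call sites above, its constructor takes (in_feats, out_feats, num_heads, feat_drop, attn_drop, negative_slope, residual, activation) positionally, it is called with (feat, graph), and its output keeps a separate dimension per attention head. A sketch on a toy graph (graph and sizes are illustrative only):

    import torch
    import torch.nn.functional as F
    import dgl
    from dgl.nn.pytorch import GATConv

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
    feat = torch.randn(4, 8)

    # positional arguments follow the call sites in gat.py above
    conv = GATConv(8, 16, 3, 0.0, 0.0, 0.2, False, F.elu)
    out = conv(feat, g)              # (feat, graph), as in GAT.forward above
    print(out.shape)                 # torch.Size([4, 3, 16]): one 16-dim output per head

    # the example then merges heads with .flatten(1) for hidden layers and .mean(1) for the output layer
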
6 changes: 3 additions & 3 deletions examples/pytorch/gat/train.py
@@ -86,7 +86,7 @@ def main(args):
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.alpha,
                args.negative_slope,
                args.residual)
    print(model)
    stopper = EarlyStopping(patience=100)
@@ -161,8 +161,8 @@ def main(args):
help="learning rate")
parser.add_argument('--weight-decay', type=float, default=5e-4,
help="weight decay")
parser.add_argument('--alpha', type=float, default=0.2,
help="the negative slop of leaky relu")
parser.add_argument('--negative-slope', type=float, default=0.2,
help="the negative slope of leaky relu")
parser.add_argument('--fastmode', action="store_true", default=False,
help="skip re-evaluate the validation set")
args = parser.parse_args()
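
Callers of this script need to switch from --alpha to --negative-slope; the meaning (LeakyReLU negative slope, default 0.2) is unchanged. For example, using only the flags visible in this diff (other options such as the dataset flag are omitted here):

    python train.py --negative-slope 0.2 --weight-decay 5e-4
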
4 changes: 2 additions & 2 deletions examples/pytorch/gcn/train.py
@@ -57,8 +57,8 @@ def main(args):
    g = data.graph
    # add self loop
    if args.self_loop:
        g.remove_edges_from(g.selfloop_edges())
        g.add_edges_from(zip(g.nodes(), g.nodes()))
        g.remove_edges_from(g.selfloop_edges())
        g.add_edges_from(zip(g.nodes(), g.nodes()))
    g = DGLGraph(g)
    n_edges = g.number_of_edges()
    # normalization