[Model] Add APPNP model (dmlc#480)
* [Model] Add APPNP model

* update

* Revert "update"

This reverts commit a8e42d1.

* update

* Update appnp.py
aymenwah authored and mufeili committed Apr 9, 2019
1 parent bfdd1ea commit fa887f6
Showing 3 changed files with 244 additions and 0 deletions.
36 changes: 36 additions & 0 deletions examples/pytorch/appnp/README.md
@@ -0,0 +1,36 @@
Predict then Propagate: Graph Neural Networks meet Personalized PageRank (APPNP)
============

- Paper link: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/abs/1810.05997)
- Author's code repo: [https://github.com/klicperajo/ppnp](https://github.com/klicperajo/ppnp).

Dependencies
------------
- PyTorch 0.4.1+
- requests

```bash
pip install torch requests
```

Code
-----
The folder contains an implementation of APPNP (`appnp.py`).

Results
-------

Run with the following command (available datasets: "cora", "citeseer", "pubmed"):
```bash
python train.py --dataset cora --gpu 0
```

* cora: 0.8370 (paper: 0.850)
* citeseer: 0.715 (paper: 0.757)
* pubmed: 0.793 (paper: 0.797)

Differences from the original implementation
---------

- This implementation does not perform dropout on the adjacency matrix during the propagation step; a sketch of what that would look like is given right after this list.
- Experiments were run on the DGL datasets (GCN settings), which differ from those used in the original implementation; the discrepancies are detailed in the experimental section of the original paper.
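
For illustration only, here is a minimal sketch of what adjacency dropout during propagation could look like on a dense, symmetrically normalized adjacency matrix. It is an assumption about the reference implementation, not code from this example; `propagate_with_adj_dropout`, `A_hat`, and `h0` are hypothetical names.

```python
import torch.nn.functional as F

def propagate_with_adj_dropout(A_hat, h0, alpha, k, edge_drop):
    """Hypothetical APPNP propagation with dropout applied to a dense,
    symmetrically normalized adjacency matrix A_hat (N x N).
    h0 is the MLP prediction (N x n_classes)."""
    h = h0
    for _ in range(k):
        # resample a dropout mask over the adjacency entries at every
        # propagation step (F.dropout also rescales by 1 / (1 - p))
        A_drop = F.dropout(A_hat, p=edge_drop, training=True)
        h = (1 - alpha) * (A_drop @ h) + alpha * h0
    return h
```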
70 changes: 70 additions & 0 deletions examples/pytorch/appnp/appnp.py
@@ -0,0 +1,70 @@
"""
APPNP implementation in DGL.
References
----------
Paper: https://arxiv.org/abs/1810.05997
Author's code: https://github.com/klicperajo/ppnp
"""

import torch.nn as nn
import dgl.function as fn


class APPNP(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 hiddens,
                 n_classes,
                 activation,
                 dropout,
                 alpha,
                 k):
        super(APPNP, self).__init__()
        self.layers = nn.ModuleList()
        self.g = g
        # input layer
        self.layers.append(nn.Linear(in_feats, hiddens[0]))
        # hidden layers
        for i in range(1, len(hiddens)):
            self.layers.append(nn.Linear(hiddens[i - 1], hiddens[i]))
        # output layer
        self.layers.append(nn.Linear(hiddens[-1], n_classes))
        self.activation = activation
        # nn.Dropout with p=0 is a no-op, so the module can always be applied
        self.dropout = nn.Dropout(p=dropout)
        self.K = k
        self.alpha = alpha

    def reset_parameters(self):
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, features):
        # prediction step: an MLP on the raw features
        h = self.dropout(features)
        h = self.activation(self.layers[0](h))
        for layer in self.layers[1:-1]:
            h = self.activation(layer(h))
        h = self.layers[-1](self.dropout(h))
        # propagation step without dropout on adjacency matrices
        self.cached_h = h
        for _ in range(self.K):
            # normalization by square root of src degree
            h = h * self.g.ndata['norm']
            self.g.ndata['h'] = h
            # message passing without adjacency dropout
            self.g.update_all(fn.copy_src(src='h', out='m'),
                              fn.sum(msg='m', out='h'))
            h = self.g.ndata.pop('h')
            # normalization by square root of dst degree
            h = h * self.g.ndata['norm']
            # interpolate with the initial prediction (teleport probability alpha)
            h = h * (1 - self.alpha) + self.cached_h * self.alpha

        return h
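
For reference, the propagation loop above implements the approximate personalized PageRank iteration from the paper, with the MLP prediction as the teleport distribution:

```latex
H^{(0)} = f_\theta(X), \qquad
H^{(k+1)} = (1 - \alpha)\, \hat{\tilde{A}}\, H^{(k)} + \alpha\, H^{(0)},
\qquad
\hat{\tilde{A}} = \tilde{D}^{-1/2}\, \tilde{A}\, \tilde{D}^{-1/2}
```

Here Ã = A + I is the adjacency matrix with self-loops and D̃ its degree matrix; the two element-wise multiplications by `g.ndata['norm']` in the loop apply the D̃^{-1/2} factors on the source and destination sides.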
138 changes: 138 additions & 0 deletions examples/pytorch/appnp/train.py
@@ -0,0 +1,138 @@
import argparse, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
import dgl
from appnp import APPNP

def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def main(args):
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.sum().item(),
           val_mask.sum().item(),
           test_mask.sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    # graph preprocess and calculate normalization factor
    g = DGLGraph(data.graph)
    n_edges = g.number_of_edges()
    # add self loop
    g.add_edges(g.nodes(), g.nodes())
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    # normalization
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    if cuda:
        norm = norm.cuda()
    g.ndata['norm'] = norm.unsqueeze(1)
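    # the factor stored in 'norm' is D̃^{-1/2} for the self-loop graph;
    # APPNP.forward multiplies node features by it before and after message
    # passing, which amounts to propagating with the symmetrically
    # normalized adjacency D̃^{-1/2} (A + I) D̃^{-1/2} from the paper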

    # create APPNP model
    model = APPNP(g,
                  in_feats,
                  args.hidden_sizes,
                  n_classes,
                  F.relu,
                  args.dropout,
                  args.alpha,
                  args.k)

    if cuda:
        model.cuda()
    model.reset_parameters()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # training loop
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, features, labels, val_mask)
        # timings are only collected from epoch 3 onwards
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, loss.item(),
                                            acc, n_edges / mean_dur / 1000))

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='APPNP')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--hidden_sizes", type=int, nargs='+', default=[64],
                        help="hidden unit sizes for appnp")
    parser.add_argument("--k", type=int, default=10,
                        help="number of propagation steps")
    parser.add_argument("--alpha", type=float, default=0.1,
                        help="teleport probability")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="weight for L2 loss")
    args = parser.parse_args()
    print(args)

    main(args)
