[NN] Rework RelGraphConv and HGTConv (dmlc#3742)
* WIP: TypedLinear and new RelGraphConv

* wip

* further simplify RGCN

* a bunch of tweaks for performance; add basic CPU support

* update on segmm

* wip: segment.cu

* new backward kernel works

* fix a bunch of bugs in kernel; leave idx_a for future

* add nn test for typed_linear

* rgcn nn test

* bugfix in corner case; update RGCN README

* doc

* fix cpp lint

* fix lint

* fix ut

* wip: hgtconv; presorted flag for rgcn

* hgt code and ut; WIP: some fix on reorder graph

* better typed linear init

* fix ut

* fix lint; add docstring
jermainewang authored Feb 23, 2022
1 parent 4f00d5a commit 0227ddf
Showing 28 changed files with 1,272 additions and 1,322 deletions.
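
The headline change: `RelGraphConv` is rebuilt on top of a new `TypedLinear` module, the `low_mem` flag is dropped, and `HGTConv` gets the same typed treatment. Below is a minimal usage sketch of the reworked modules; the exact constructor and `forward` argument names are assumptions inferred from the post-rework API, not something this page states, so treat it as illustrative rather than definitive.

```python
# Hedged sketch of the reworked modules; argument names are assumptions.
import dgl
import torch
from dgl.nn.pytorch import TypedLinear, RelGraphConv, HGTConv

# TypedLinear: one linear transform per type, optionally basis-regularized.
x = torch.randn(5, 16)                     # 5 inputs, 16 features each
x_type = torch.randint(0, 3, (5,))         # a type id in [0, 3) per row
typed_lin = TypedLinear(16, 32, 3, regularizer='basis', num_bases=2)
y = typed_lin(x, x_type)                   # shape (5, 32)

# RelGraphConv: relation-specific graph convolution; the low_mem flag is gone.
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
etypes = torch.tensor([0, 1, 1])           # one relation id per edge
conv = RelGraphConv(16, 8, num_rels=2, regularizer='basis', num_bases=2)
h = conv(g, torch.randn(3, 16), etypes)    # shape (3, 8)

# HGTConv follows the same pattern on a homogeneous graph with type ids.
hgt = HGTConv(16, head_size=4, num_heads=2, num_ntypes=1, num_etypes=2)
ntypes = torch.zeros(3, dtype=torch.int64)  # one node-type id per node
z = hgt(g, torch.randn(3, 16), ntypes, etypes)  # shape (3, 4 * 2)
```
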
17 changes: 6 additions & 11 deletions benchmarks/benchmarks/model_speed/bench_rgcn.py
@@ -16,22 +16,19 @@ def __init__(self,
num_rels,
num_bases,
num_hidden_layers,
dropout,
lowmem):
dropout):
super(RGCN, self).__init__()
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConv(num_nodes, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=lowmem))
num_bases, activation=F.relu, dropout=dropout))
# h2h
for i in range(num_hidden_layers):
self.layers.append(RelGraphConv(n_hidden, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=lowmem))
num_bases, activation=F.relu, dropout=dropout))
# o2h
self.layers.append(RelGraphConv(n_hidden, num_classes, num_rels, "basis",
num_bases, activation=None, low_mem=lowmem))
num_bases, activation=None))

def forward(self, g, h, r, norm):
for layer in self.layers:
@@ -40,9 +37,8 @@ def forward(self, g, h, r, norm):

@utils.benchmark('time', 300)
@utils.parametrize('data', ['aifb'])
@utils.parametrize('lowmem', [True, False])
@utils.parametrize('use_type_count', [True, False])
def track_time(data, lowmem, use_type_count):
def track_time(data, use_type_count):
# args
if data == 'aifb':
num_bases = -1
@@ -108,8 +104,7 @@ def track_time(data, lowmem, use_type_count):
num_rels,
num_bases,
0,
0,
lowmem).to(device)
0).to(device)

optimizer = torch.optim.Adam(model.parameters(),
lr=1e-2,
10 changes: 9 additions & 1 deletion docs/source/api/python/nn.pytorch.rst
@@ -295,7 +295,7 @@ TransR
:members: rel_emb, rel_project, forward, reset_parameters
:show-inheritance:

Heterogeneous Graph Convolution Module
Heterogeneous Learning Module
----------------------------------------

HeteroGraphConv
@@ -319,9 +319,17 @@ HeteroEmbedding

.. _apinn-pytorch-util:


Utility Modules
----------------------------------------

TypedLinear
----------------------------------------

.. autoclass:: dgl.nn.pytorch.TypedLinear
:members: forward
:show-inheritance:

Sequential
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

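
The docs entry above only lists `forward`, but the commit messages also mention a presorted flag and reorder-graph fixes, i.e., a fast path when inputs arrive already grouped by type. The snippet below is a hypothetical illustration of that idea; the `sorted_by_type` keyword is an assumption and does not appear in this diff.

```python
# Hypothetical illustration of the sorted-input fast path; the keyword name
# sorted_by_type is an assumption based on the "presorted" wording above.
import torch
from dgl.nn.pytorch import TypedLinear

x = torch.randn(6, 16)
x_type = torch.tensor([0, 0, 1, 1, 2, 2])   # rows already grouped by type
layer = TypedLinear(16, 32, 3)
y = layer(x, x_type, sorted_by_type=True)   # lets the layer skip a sort
```
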
43 changes: 11 additions & 32 deletions examples/pytorch/rgcn/README.md
@@ -18,57 +18,36 @@ pip install rdflib pandas
Example code was tested with rdflib 4.2.2 and pandas 0.23.4

### Entity Classification
AIFB: accuracy 96.29% (3 runs, DGL), 95.83% (paper)
```
python entity.py -d aifb --l2norm 0 --gpu 0
```

MUTAG: accuracy 72.55% (3 runs, DGL), 73.23% (paper)
For AIFB, MUTAG, BGS and AM,
```
python entity.py -d aifb --wd 0 --gpu 0
python entity.py -d mutag --n-bases 30 --gpu 0
```

BGS: accuracy 89.70% (3 runs, DGL), 83.10% (paper)
```
python entity.py -d bgs --n-bases 40 --gpu 0
```

AM: accuracy 89.56% (3 runs, DGL), 89.29% (paper)
```
python entity.py -d am --n-bases 40 --n-hidden 10
python entity.py -d am --n-bases 40 --n-hidden 10 --gpu 0
```

### Entity Classification with minibatch

AIFB: accuracy avg(5 runs) 91.10%, best 97.22% (DGL)
For AIFB, MUTAG, BGS and AM,
```
python entity_sample.py -d aifb --l2norm 0 --gpu 0 --fanout='20,20' --batch-size 128
python entity_sample.py -d aifb --wd 0 --gpu 0 --fanout='20,20' --batch-size 128
python entity_sample.py -d mutag --n-bases 30 --gpu 0 --batch-size 64 --fanout='-1,-1' --use-self-loop --n-epochs 20 --dropout 0.5
python entity_sample.py -d bgs --n-bases 40 --gpu 0 --fanout='-1,-1' --n-epochs=16 --batch-size=16 --dropout 0.3
python entity_sample.py -d am --n-bases 40 --gpu 0 --fanout='35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --dropout 0.7
```

MUTAG: accuracy avg(10 runs) 66.47%, best 72.06% (DGL)
```
python entity_sample.py -d mutag --n-bases 30 --gpu 0 --batch-size 64 --fanout "-1, -1" --use-self-loop --n-epochs 20 --sparse-lr 0.01 --dropout 0.5
```

BGS: accuracy avg(5 runs) 84.83%, best 89.66% (DGL)
```
python entity_sample.py -d bgs --n-bases 40 --gpu 0 --fanout "-1, -1" --n-epochs=16 --batch-size=16 --sparse-lr 0.05 --dropout 0.3
```

AM: accuracy avg(5 runs) 88.58%, best 89.90% (DGL)
```
python entity_sample.py -d am --n-bases 40 --gpu 0 --fanout '35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --sparse-lr 0.02 --dropout 0.7
```
### Entity Classification on multiple GPUs

To use multiple GPUs, replace `entity_sample.py` with `entity_sample_multi_gpu.py` and specify
multiple GPU IDs separated by comma, e.g., `--gpu 0,1`.

### Link Prediction
FB15k-237: MRR 0.163 (DGL), 0.158 (paper)
FB15k-237 in RAW-MRR
```
python link.py --gpu 0 --eval-protocol raw
```
FB15k-237: Filtered-MRR 0.247
FB15k-237 in Filtered-MRR
```
python link.py --gpu 0 --eval-protocol filtered
```
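
For reference, both eval protocols report MRR, the mean of reciprocal ranks of the true entities; the filtered protocol additionally removes other known true triples from each candidate list before ranking. A minimal sketch of the metric itself (rank computation omitted):

```python
# Mean reciprocal rank over a list of ranks of the true entities.
def mrr(ranks):
    return sum(1.0 / r for r in ranks) / len(ranks)

print(mrr([1, 2, 4]))  # 0.5833...
```
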
27 changes: 9 additions & 18 deletions examples/pytorch/rgcn/entity.py
@@ -1,9 +1,7 @@
"""
Differences compared to tkipf/relation-gcn
* l2norm applied to all weights
* remove nodes that won't be touched
* weight decay applied to all weights
"""

import argparse
import torch as th
import torch.nn.functional as F
@@ -17,13 +15,7 @@ def main(args):
g, num_rels, num_classes, labels, train_idx, test_idx, target_idx = load_data(
args.dataset, get_norm=True)

num_nodes = g.num_nodes()

# Since the nodes are featureless, learn node embeddings from scratch
# This requires passing the node IDs to the model.
feats = th.arange(num_nodes)

model = RGCN(num_nodes,
model = RGCN(g.num_nodes(),
args.n_hidden,
num_classes,
num_rels,
@@ -33,16 +25,15 @@
device = th.device(args.gpu)
else:
device = th.device('cpu')
feats = feats.to(device)
labels = labels.to(device)
model = model.to(device)
g = g.to(device)
g = g.int().to(device)

optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)

model.train()
for epoch in range(50):
logits = model(g, feats)
for epoch in range(100):
logits = model(g)
logits = logits[target_idx]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
optimizer.zero_grad()
@@ -56,7 +47,7 @@ def main(args):

model.eval()
with th.no_grad():
logits = model(g, feats)
logits = model(g)
logits = logits[target_idx]
test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()
print("Test Accuracy: {:.4f}".format(test_acc))
@@ -72,8 +63,8 @@ def main(args):
parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4,
help="l2 norm coef")
parser.add_argument("--wd", type=float, default=5e-4,
help="weight decay")

args = parser.parse_args()
print(args)
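
After this change `entity.py` calls `model(g)` with no feature tensor because the example's RGCN (defined in the example's `model.py`, which is not shown on this page) learns the featureless-node embeddings itself. A hedged sketch of that pattern; the class name, the `'etype'` edge-data key, and the two-layer layout are invented for illustration:

```python
# Hypothetical sketch of an RGCN that owns its node embedding, matching the
# model(g) call pattern above; the real class lives in the example's model.py.
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import RelGraphConv

class RGCNSketch(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)    # learned from scratch
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels)

    def forward(self, g):
        etypes = g.edata['etype']                    # assumed storage key
        h = self.emb.weight                          # one row per node id
        h = F.relu(self.conv1(g, h, etypes))
        return self.conv2(g, h, etypes)
```
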
63 changes: 22 additions & 41 deletions examples/pytorch/rgcn/entity_sample.py
@@ -1,6 +1,6 @@
"""
Differences compared to tkipf/relation-gcn
* l2norm applied to all weights
* weight decay applied to all weights
* remove nodes that won't be touched
"""
import argparse
@@ -13,7 +13,7 @@
from tqdm import tqdm

from entity_utils import load_data
from model import RelGraphEmbedLayer, RGCN
from model import RGCN

def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=False):
fanouts = [int(fanout) for fanout in args.fanout.split(',')]
@@ -54,21 +54,6 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F

return train_loader, val_loader, test_loader

def init_models(args, device, num_nodes, num_classes, num_rels):
embed_layer = RelGraphEmbedLayer(device,
num_nodes,
args.n_hidden)

model = RGCN(args.n_hidden,
args.n_hidden,
num_classes,
num_rels,
num_bases=args.n_bases,
dropout=args.dropout,
self_loop=args.use_self_loop)

return embed_layer, model

def process_batch(inv_target, batch):
_, seeds, blocks = batch
# map the seed nodes back to their type-specific ids,
@@ -80,38 +65,32 @@ def process_batch(inv_target, batch):

return seeds, blocks

def train(model, embed_layer, train_loader, inv_target,
labels, emb_optimizer, optimizer):
def train(model, train_loader, inv_target,
labels, optimizer):
model.train()
embed_layer.train()

for sample_data in train_loader:
seeds, blocks = process_batch(inv_target, sample_data)
feats = embed_layer(blocks[0].srcdata[dgl.NID].cpu())
logits = model(blocks, feats)
logits = model.forward(blocks)
loss = F.cross_entropy(logits, labels[seeds])
emb_optimizer.zero_grad()
optimizer.zero_grad()

optimizer.zero_grad()
loss.backward()
emb_optimizer.step()
optimizer.step()

train_acc = accuracy(logits.argmax(dim=1), labels[seeds]).item()

return train_acc, loss.item()

def evaluate(model, embed_layer, eval_loader, inv_target):
def evaluate(model, eval_loader, inv_target):
model.eval()
embed_layer.eval()
eval_logits = []
eval_seeds = []

with th.no_grad():
for sample_data in tqdm(eval_loader):
seeds, blocks = process_batch(inv_target, sample_data)
feats = embed_layer(blocks[0].srcdata[dgl.NID].cpu())
logits = model(blocks, feats)
logits = model.forward(blocks)
eval_logits.append(logits.cpu().detach())
eval_seeds.append(seeds.cpu().detach())

@@ -131,26 +110,30 @@ def main(args):

train_loader, val_loader, test_loader = init_dataloaders(
args, g, train_idx, test_idx, target_idx, args.gpu)
embed_layer, model = init_models(args, device, g.num_nodes(), num_classes, num_rels)

model = RGCN(g.num_nodes(),
args.n_hidden,
num_classes,
num_rels,
num_bases=args.n_bases,
dropout=args.dropout,
self_loop=args.use_self_loop,
ns_mode=True)
labels = labels.to(device)
model = model.to(device)

emb_optimizer = th.optim.SparseAdam(embed_layer.parameters(), lr=args.sparse_lr)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)

for epoch in range(args.n_epochs):
train_acc, loss = train(model, embed_layer, train_loader, inv_target,
labels, emb_optimizer, optimizer)
train_acc, loss = train(model, train_loader, inv_target, labels, optimizer)
print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
epoch, args.n_epochs, train_acc, loss))

val_logits, val_seeds = evaluate(model, embed_layer, val_loader, inv_target)
val_logits, val_seeds = evaluate(model, val_loader, inv_target)
val_acc = accuracy(val_logits.argmax(dim=1), labels[val_seeds].cpu()).item()
print("Validation Accuracy: {:.4f}".format(val_acc))

test_logits, test_seeds = evaluate(model, embed_layer,
test_loader, inv_target)
test_logits, test_seeds = evaluate(model, test_loader, inv_target)
test_acc = accuracy(test_logits.argmax(dim=1), labels[test_seeds].cpu()).item()
print("Final Test Accuracy: {:.4f}".format(test_acc))

@@ -162,17 +145,15 @@ def main(args):
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=0,
help="gpu")
parser.add_argument("--sparse-lr", type=float, default=2e-2,
help="sparse embedding learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-epochs", type=int, default=50,
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4,
help="l2 norm coef")
parser.add_argument("--wd", type=float, default=5e-4,
help="weight decay")
parser.add_argument("--fanout", type=str, default="4, 4",
help="Fan-out of neighbor sampling")
parser.add_argument("--use-self-loop", default=False, action='store_true',
(The diffs for the remaining changed files are not shown.)
