forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update (dmlc#5) * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * FIx * Try * Update * Update * Update * Fix * Update * Fix * Fix * Fix * Fix * Update * Fix * Update * Update * Update * Fix * Fix * Update * Update * Update * Update * Fix * Fix * Fix * Update * Update * Update * Update * Update * Update README.md * Update * Fix * Update * Update * Fix * Fix * Fix * Update * Update * Update Co-authored-by: Ubuntu <[email protected]> * Update * Update * Fix * Update * Update * Update * Fix * Update * Update * Update * Update * Update * Update * CI Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Quan (Andy) Gan <[email protected]>
- Loading branch information
1 parent
22272de
commit e9c3c0e
Showing
21 changed files
with
1,036 additions
and
1,588 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,84 +1,74 @@ | ||
# Relational-GCN | ||
|
||
* Paper: [https://arxiv.org/abs/1703.06103](https://arxiv.org/abs/1703.06103) | ||
* Paper: [Modeling Relational Data with Graph Convolutional Networks](https://arxiv.org/abs/1703.06103) | ||
* Author's code for entity classification: [https://github.com/tkipf/relational-gcn](https://github.com/tkipf/relational-gcn) | ||
* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction) | ||
|
||
### Dependencies | ||
* PyTorch 0.4.1+ | ||
* requests | ||
* PyTorch 1.10 | ||
* rdflib | ||
* pandas | ||
* tqdm | ||
* TorchMetrics | ||
|
||
``` | ||
pip install requests torch rdflib pandas | ||
pip install rdflib pandas | ||
``` | ||
|
||
Example code was tested with rdflib 4.2.2 and pandas 0.23.4 | ||
|
||
### Entity Classification | ||
AIFB: accuracy 96.29% (3 runs, DGL), 95.83% (paper) | ||
``` | ||
python3 entity_classify.py -d aifb --testing --gpu 0 | ||
python entity.py -d aifb --l2norm 0 --gpu 0 | ||
``` | ||
|
||
MUTAG: accuracy 70.59% (3 runs, DGL), 73.23% (paper) | ||
MUTAG: accuracy 72.55% (3 runs, DGL), 73.23% (paper) | ||
``` | ||
python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 | ||
python entity.py -d mutag --n-bases 30 --gpu 0 | ||
``` | ||
|
||
BGS: accuracy 93.10% (3 runs, DGL), 83.10% (paper) | ||
BGS: accuracy 89.70% (3 runs, DGL), 83.10% (paper) | ||
``` | ||
python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 | ||
python entity.py -d bgs --n-bases 40 --gpu 0 | ||
``` | ||
|
||
AM: accuracy 89.22% (3 runs, DGL), 89.29% (paper) | ||
AM: accuracy 89.56% (3 runs, DGL), 89.29% (paper) | ||
``` | ||
python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --testing | ||
python entity.py -d am --n-bases 40 --n-hidden 10 | ||
``` | ||
|
||
### Entity Classification with minibatch | ||
AIFB: accuracy avg(5 runs) 90.00%, best 94.44% (DGL) | ||
``` | ||
python3 entity_classify_mp.py -d aifb --testing --gpu 0 --fanout='20,20' --batch-size 128 | ||
``` | ||
|
||
MUTAG: accuracy avg(10 runs) 62.94%, best 72.06% (DGL) | ||
``` | ||
python3 entity_classify_mp.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 --batch-size 64 --fanout "-1, -1" --use-self-loop --dgl-sparse --n-epochs 20 --sparse-lr 0.01 --dropout 0.5 | ||
``` | ||
|
||
BGS: accuracy avg(5 runs) 78.62%, best 86.21% (DGL) | ||
AIFB: accuracy avg(5 runs) 91.10%, best 97.22% (DGL) | ||
``` | ||
python3 entity_classify_mp.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout "-1, -1" --n-epochs=16 --batch-size=16 --dgl-sparse --lr 0.01 --sparse-lr 0.05 --dropout 0.3 | ||
python entity_sample.py -d aifb --l2norm 0 --gpu 0 --fanout='20,20' --batch-size 128 | ||
``` | ||
|
||
AM: accuracy avg(5 runs) 87.37%, best 89.9% (DGL) | ||
MUTAG: accuracy avg(10 runs) 66.47%, best 72.06% (DGL) | ||
``` | ||
python3 entity_classify_mp.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout '35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --dgl-sparse --lr 0.01 --sparse-lr 0.02 --dropout 0.7 | ||
python entity_sample.py -d mutag --n-bases 30 --gpu 0 --batch-size 64 --fanout "-1, -1" --use-self-loop --n-epochs 20 --sparse-lr 0.01 --dropout 0.5 | ||
``` | ||
|
||
### Entity Classification on OGBN-MAG | ||
Test-bd: P3-8xlarge | ||
|
||
OGBN-MAG accuracy 45.5 (3 runs) | ||
BGS: accuracy avg(5 runs) 84.83%, best 89.66% (DGL) | ||
``` | ||
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='30,30' --batch-size 1024 --n-hidden 128 --lr 0.01 --num-worker 4 --eval-batch-size 8 --low-mem --gpu 0,1,2,3 --dropout 0.7 --use-self-loop --n-bases 2 --n-epochs 3 --node-feats --dgl-sparse --sparse-lr 0.08 | ||
python entity_sample.py -d bgs --n-bases 40 --gpu 0 --fanout "-1, -1" --n-epochs=16 --batch-size=16 --sparse-lr 0.05 --dropout 0.3 | ||
``` | ||
|
||
OGBN-MAG without node-feats 42.79 | ||
AM: accuracy avg(5 runs) 88.58%, best 89.90% (DGL) | ||
``` | ||
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='30,30' --batch-size 1024 --n-hidden 128 --lr 0.01 --num-worker 4 --eval-batch-size 8 --low-mem --gpu 0,1,2,3 --dropout 0.7 --use-self-loop --n-bases 2 --n-epochs 3 --dgl-sparse --sparse-lr 0.08 | ||
python entity_sample.py -d am --n-bases 40 --gpu 0 --fanout '35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --sparse-lr 0.02 --dropout 0.7 | ||
``` | ||
|
||
Test-bd: P2-8xlarge | ||
To use multiple GPUs, replace `entity_sample.py` with `entity_sample_multi_gpu.py` and specify | ||
multiple GPU IDs separated by comma, e.g., `--gpu 0,1`. | ||
|
||
### Link Prediction | ||
FB15k-237: MRR 0.151 (DGL), 0.158 (paper) | ||
FB15k-237: MRR 0.163 (DGL), 0.158 (paper) | ||
``` | ||
python3 link_predict.py -d FB15k-237 --gpu 0 --eval-protocol raw | ||
python link.py --gpu 0 --eval-protocol raw | ||
``` | ||
FB15k-237: Filtered-MRR 0.2044 | ||
FB15k-237: Filtered-MRR 0.247 | ||
``` | ||
python3 link_predict.py -d FB15k-237 --gpu 0 --eval-protocol filtered | ||
python link.py --gpu 0 --eval-protocol filtered | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
Differences compared to tkipf/relation-gcn | ||
* l2norm applied to all weights | ||
* remove nodes that won't be touched | ||
""" | ||
|
||
import argparse | ||
import torch as th | ||
import torch.nn.functional as F | ||
|
||
from torchmetrics.functional import accuracy | ||
|
||
from entity_utils import load_data | ||
from model import RGCN | ||
|
||
def main(args): | ||
g, num_rels, num_classes, labels, train_idx, test_idx, target_idx = load_data( | ||
args.dataset, get_norm=True) | ||
|
||
num_nodes = g.num_nodes() | ||
|
||
# Since the nodes are featureless, learn node embeddings from scratch | ||
# This requires passing the node IDs to the model. | ||
feats = th.arange(num_nodes) | ||
|
||
model = RGCN(num_nodes, | ||
args.n_hidden, | ||
num_classes, | ||
num_rels, | ||
num_bases=args.n_bases) | ||
|
||
if args.gpu >= 0 and th.cuda.is_available(): | ||
device = th.device(args.gpu) | ||
else: | ||
device = th.device('cpu') | ||
feats = feats.to(device) | ||
labels = labels.to(device) | ||
model = model.to(device) | ||
g = g.to(device) | ||
|
||
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm) | ||
|
||
model.train() | ||
for epoch in range(50): | ||
logits = model(g, feats) | ||
logits = logits[target_idx] | ||
loss = F.cross_entropy(logits[train_idx], labels[train_idx]) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
train_acc = accuracy(logits[train_idx].argmax(dim=1), labels[train_idx]).item() | ||
print("Epoch {:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format( | ||
epoch, train_acc, loss.item())) | ||
print() | ||
|
||
model.eval() | ||
with th.no_grad(): | ||
logits = model(g, feats) | ||
logits = logits[target_idx] | ||
test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item() | ||
print("Test Accuracy: {:.4f}".format(test_acc)) | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='RGCN for entity classification') | ||
parser.add_argument("--n-hidden", type=int, default=16, | ||
help="number of hidden units") | ||
parser.add_argument("--gpu", type=int, default=-1, | ||
help="gpu") | ||
parser.add_argument("--n-bases", type=int, default=-1, | ||
help="number of filter weight matrices, default: -1 [use all]") | ||
parser.add_argument("-d", "--dataset", type=str, required=True, | ||
choices=['aifb', 'mutag', 'bgs', 'am'], | ||
help="dataset to use") | ||
parser.add_argument("--l2norm", type=float, default=5e-4, | ||
help="l2 norm coef") | ||
|
||
args = parser.parse_args() | ||
print(args) | ||
main(args) |
Oops, something went wrong.