-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_classify.py
79 lines (71 loc) · 2.77 KB
/
test_classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
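"""Inferring Relational Data with Graph Convolutional Networks

Loads a saved EntityClassify checkpoint and prints its accuracy and loss on
the test split of a DGL RDF dataset.
"""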
import argparse
import torch as th
from functools import partial
import torch.nn.functional as F
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
from entity_classify import EntityClassify
def main(args):
    """Evaluate a saved EntityClassify model on the test split of an RDF dataset.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``dataset``, ``gpu``, ``n_hidden``, ``n_bases``, ``n_layers``,
            ``use_self_loop`` and ``model_path``.

    Raises:
        ValueError: If ``args.dataset`` is not one of 'aifb', 'mutag',
            'bgs' or 'am'.
    """
    # CLI dataset name -> class name inside dgl.data.rdf.  The name is
    # validated first so a typo fails fast with a readable message, before
    # any dataset download or graph construction is attempted.
    dataset_classes = {
        'aifb': 'AIFBDataset',
        'mutag': 'MUTAGDataset',
        'bgs': 'BGSDataset',
        'am': 'AMDataset',
    }
    if args.dataset not in dataset_classes:
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(
                args.dataset, sorted(dataset_classes)))

    # Resolved locally: the module-level import only brings in the names
    # AIFB/MUTAG/BGS/AM, not the *Dataset classes this code instantiates.
    from dgl.data import rdf
    dataset = getattr(rdf, dataset_classes[args.dataset])()

    # load graph data
    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    test_mask = g.nodes[category].data.pop('test_mask')
    # squeeze(1) drops only the trailing column produced by nonzero, so the
    # index stays 1-D (and len() keeps working) even with a single test node.
    # NOTE(review): assumes test_mask is a 1-D per-node mask -- confirm.
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1)
    labels = g.nodes[category].data.pop('labels')

    # check cuda
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        labels = labels.cuda()
        test_idx = test_idx.cuda()
        g = g.to('cuda:%d' % args.gpu)

    # create model
    # NOTE(review): "n_layers - 2" presumably excludes the input and output
    # layers from the hidden-layer count -- confirm against EntityClassify.
    model = EntityClassify(g,
                           args.n_hidden,
                           num_classes,
                           num_bases=args.n_bases,
                           num_hidden_layers=args.n_layers - 2,
                           use_self_loop=args.use_self_loop)
    # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
    # CPU-only host; the model is moved to the GPU afterwards when requested.
    # SECURITY: th.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(th.load(args.model_path, map_location='cpu'))
    if use_cuda:
        model.cuda()

    print("start testing...")
    model.eval()
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with th.no_grad():
        logits = model()[category]
        test_logits = logits[test_idx]
        test_labels = labels[test_idx]
        test_loss = F.cross_entropy(test_logits, test_labels)
        num_correct = th.sum(test_logits.argmax(dim=1) == test_labels).item()
    test_acc = num_correct / len(test_idx)
    print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
    print()
if __name__ == '__main__':
    # Command-line entry point: parse the evaluation options, echo them,
    # and hand them to main().
    parser = argparse.ArgumentParser(description='RGCN')
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden units")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    # Accepted for command-line compatibility; main() does not read it.
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-bases", type=int, default=-1,
                        help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("--n-layers", type=int, default=2,
                        help="number of propagation rounds")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        help="dataset to use")
    # Required: main() always loads a checkpoint, so a missing path would
    # otherwise only surface later as an obscure failure inside th.load.
    # "--model-path" matches the hyphenated style of the other flags; the
    # original "--model_path" spelling is kept as an alias, and both store
    # into args.model_path.
    parser.add_argument("--model-path", "--model_path", type=str, required=True,
                        help='path of the model to load from')
    parser.add_argument("--use-self-loop", default=False, action='store_true',
                        help="include self feature as a special relation")

    args = parser.parse_args()
    print(args)
    main(args)