From cffa4034f580e33fe4295e9f1b54217e7fa724eb Mon Sep 17 00:00:00 2001 From: Ziyue Huang Date: Mon, 17 Dec 2018 01:44:41 +0800 Subject: [PATCH] [Model] fix GCN (#305) * mxnet gcn spmv * update readme * fix gcn * pytorch gcn * update readme --- examples/mxnet/gcn/README.md | 5 +- examples/mxnet/gcn/gcn.py | 220 +++++++++++++++++++++++++++++++ examples/mxnet/gcn/gcn_batch.py | 190 -------------------------- examples/mxnet/gcn/gcn_concat.py | 195 +++++++++++++++++++++++++++ examples/pytorch/gcn/gcn.py | 120 +++++++++++------ 5 files changed, 497 insertions(+), 233 deletions(-) create mode 100644 examples/mxnet/gcn/gcn.py delete mode 100644 examples/mxnet/gcn/gcn_batch.py create mode 100644 examples/mxnet/gcn/gcn_concat.py diff --git a/examples/mxnet/gcn/README.md b/examples/mxnet/gcn/README.md index 2634bf305ac1..dd83bf49b96e 100644 --- a/examples/mxnet/gcn/README.md +++ b/examples/mxnet/gcn/README.md @@ -6,10 +6,13 @@ Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn) Codes ----- -The folder contains two implementations of GCN. `gcn_batch.py` uses user-defined +The folder contains two implementations of GCN. `gcn.py` uses user-defined message and reduce functions. `gcn_spmv.py` uses DGL's builtin functions so SPMV optimization could be applied. +The provided implementation in `gcn_concat.py` is a bit different from the +original paper for better performance, credit to @yifeim and @ZiyueHuang. + Results ------- These results are based on single-run training to minimize the cross-entropy diff --git a/examples/mxnet/gcn/gcn.py b/examples/mxnet/gcn/gcn.py new file mode 100644 index 000000000000..3463a6b821e8 --- /dev/null +++ b/examples/mxnet/gcn/gcn.py @@ -0,0 +1,220 @@ +""" +Semi-Supervised Classification with Graph Convolutional Networks +Paper: https://arxiv.org/abs/1609.02907 +Code: https://github.com/tkipf/gcn +GCN with SPMV optimization +""" +import argparse, time, math +import numpy as np +import mxnet as mx +from mxnet import gluon +import dgl +from dgl import DGLGraph +from dgl.data import register_data_args, load_data + + +def gcn_msg(edge): + msg = edge.src['h'] * edge.src['norm'] + return {'m': msg} + + +def gcn_reduce(node): + accum = mx.nd.sum(node.mailbox['m'], 1) * node.data['norm'] + return {'h': accum} + + +class NodeUpdate(gluon.Block): + def __init__(self, out_feats, activation=None, bias=True): + super(NodeUpdate, self).__init__() + with self.name_scope(): + if bias: + self.bias = self.params.get('bias', shape=(out_feats,), + init=mx.init.Zero()) + else: + self.bias = None + self.activation = activation + + def forward(self, node): + h = node.data['h'] + if self.bias is not None: + h = h + self.bias.data(h.context) + if self.activation: + h = self.activation(h) + return {'h': h} + + +class GCNLayer(gluon.Block): + def __init__(self, + g, + in_feats, + out_feats, + activation, + dropout, + bias=True): + super(GCNLayer, self).__init__() + self.g = g + self.dropout = dropout + with self.name_scope(): + self.weight = self.params.get('weight', shape=(in_feats, out_feats), + init=mx.init.Xavier()) + self.node_update = NodeUpdate(out_feats, activation, bias) + + def forward(self, h): + if self.dropout: + h = mx.nd.Dropout(h, p=self.dropout) + h = mx.nd.dot(h, self.weight.data(h.context)) + self.g.ndata['h'] = h + self.g.update_all(gcn_msg, gcn_reduce, self.node_update) + h = self.g.ndata.pop('h') + return h + + +class GCN(gluon.Block): + def __init__(self, + g, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout, + normalization): + super(GCN, self).__init__() + self.layers = gluon.nn.Sequential() + # input layer + self.layers.add(GCNLayer(g, in_feats, n_hidden, activation, 0)) + # hidden layers + for i in range(n_layers - 1): + self.layers.add(GCNLayer(g, n_hidden, n_hidden, activation, dropout)) + # output layer + self.layers.add(GCNLayer(g, n_hidden, n_classes, None, dropout)) + + + def forward(self, features): + h = features + for layer in self.layers: + h = layer(h) + return h + +def evaluate(model, features, labels, mask): + pred = model(features).argmax(axis=1) + accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() + return accuracy.asscalar() + +def main(args): + # load and preprocess dataset + data = load_data(args) + + if args.self_loop: + data.graph.add_edges_from([(i,i) for i in range(len(data.graph))]) + + features = mx.nd.array(data.features) + labels = mx.nd.array(data.labels) + train_mask = mx.nd.array(data.train_mask) + val_mask = mx.nd.array(data.val_mask) + test_mask = mx.nd.array(data.test_mask) + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().asscalar(), + val_mask.sum().asscalar(), + test_mask.sum().asscalar())) + + if args.gpu < 0: + cuda = False + ctx = mx.cpu(0) + else: + cuda = True + ctx = mx.gpu(args.gpu) + + features = features.as_in_context(ctx) + labels = labels.as_in_context(ctx) + train_mask = train_mask.as_in_context(ctx) + val_mask = val_mask.as_in_context(ctx) + test_mask = test_mask.as_in_context(ctx) + + # create GCN model + g = DGLGraph(data.graph) + # normalization + degs = g.in_degrees().astype('float32') + norm = mx.nd.power(degs, -0.5) + if cuda: + norm = norm.as_in_context(ctx) + g.ndata['norm'] = mx.nd.expand_dims(norm, 1) + + model = GCN(g, + in_feats, + args.n_hidden, + n_classes, + args.n_layers, + mx.nd.relu, + args.dropout, + args.normalization) + model.initialize(ctx=ctx) + n_train_samples = train_mask.sum().asscalar() + loss_fcn = gluon.loss.SoftmaxCELoss() + + # use optimizer + print(model.collect_params()) + trainer = gluon.Trainer(model.collect_params(), 'adam', + {'learning_rate': args.lr, 'wd': args.weight_decay}) + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + if epoch >= 3: + t0 = time.time() + # forward + with mx.autograd.record(): + pred = model(features) + loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1)) + loss = loss.sum() / n_train_samples + + loss.backward() + trainer.step(batch_size=1) + + if epoch >= 3: + dur.append(time.time() - t0) + acc = evaluate(model, features, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". format( + epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) + + # test set accuracy + acc = evaluate(model, features, labels, test_mask) + print("Test accuracy {:.2%}".format(acc)) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='GCN') + register_data_args(parser) + parser.add_argument("--dropout", type=float, default=0.5, + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--n-hidden", type=int, default=16, + help="number of hidden gcn units") + parser.add_argument("--n-layers", type=int, default=1, + help="number of hidden gcn layers") + parser.add_argument("--normalization", + choices=['sym','left'], default=None, + help="graph normalization types (default=None)") + parser.add_argument("--self-loop", action='store_true', + help="graph self-loop (default=False)") + parser.add_argument("--weight-decay", type=float, default=5e-4, + help="Weight for L2 loss") + args = parser.parse_args() + + print(args) + + main(args) diff --git a/examples/mxnet/gcn/gcn_batch.py b/examples/mxnet/gcn/gcn_batch.py deleted file mode 100644 index 637d53e5c8b5..000000000000 --- a/examples/mxnet/gcn/gcn_batch.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -Semi-Supervised Classification with Graph Convolutional Networks -Paper: https://arxiv.org/abs/1609.02907 -Code: https://github.com/tkipf/gcn - -GCN with batch processing -""" -import argparse -import numpy as np -import time -import mxnet as mx -from mxnet import gluon -import dgl -from dgl import DGLGraph -from dgl.data import register_data_args, load_data -from functools import partial - -def gcn_msg(edge, normalization=None): - # print('h', edge.src['h'].shape, edge.src['out_degree']) - msg = edge.src['h'] - if normalization == 'sym': - msg = msg / edge.src['out_degree'].sqrt().reshape((-1,1)) - return {'m': msg} - - -def gcn_reduce(node, normalization=None): - # print('m', node.mailbox['m'].shape, node.data['in_degree']) - accum = mx.nd.sum(node.mailbox['m'], 1) - if normalization == 'sym': - accum = accum / node.data['in_degree'].sqrt().reshape((-1,1)) - elif normalization == 'left': - accum = accum / node.data['in_degree'].reshape((-1,1)) - return {'accum': accum} - - -class NodeUpdateModule(gluon.Block): - def __init__(self, out_feats, activation=None, dropout=0): - super(NodeUpdateModule, self).__init__() - self.linear = gluon.nn.Dense(out_feats, activation=activation) - self.dropout = dropout - - def forward(self, node): - accum = self.linear(node.data['accum']) - if self.dropout: - accum = mx.nd.Dropout(accum, p=self.dropout) - return {'h': mx.nd.concat(node.data['h'], accum, dim=1)} - - -class GCN(gluon.Block): - def __init__(self, - g, - in_feats, - n_hidden, - n_classes, - n_layers, - activation, - dropout, - normalization, - ): - super(GCN, self).__init__() - self.g = g - self.dropout = dropout - - self.inp_layer = gluon.nn.Dense(n_hidden, activation) - - self.conv_layers = gluon.nn.Sequential() - for i in range(n_layers): - self.conv_layers.add(NodeUpdateModule(n_hidden, activation, dropout)) - - self.out_layer = gluon.nn.Dense(n_classes) - - self.gcn_msg = partial(gcn_msg, normalization=normalization) - self.gcn_reduce = partial(gcn_reduce, normalization=normalization) - - - def forward(self, features): - emb_inp = [features, self.inp_layer(features)] - if self.dropout: - emb_inp[-1] = mx.nd.Dropout(emb_inp[-1], p=self.dropout) - - self.g.ndata['h'] = mx.nd.concat(*emb_inp, dim=1) - for layer in self.conv_layers: - self.g.update_all(self.gcn_msg, self.gcn_reduce, layer) - - emb_out = self.g.ndata.pop('h') - return self.out_layer(emb_out) - - -def main(args): - # load and preprocess dataset - data = load_data(args) - - if args.self_loop: - data.graph.add_edges_from([(i,i) for i in range(len(data.graph))]) - - features = mx.nd.array(data.features) - labels = mx.nd.array(data.labels) - mask = mx.nd.array(data.train_mask) - in_degree = mx.nd.array([data.graph.in_degree(i) - for i in range(len(data.graph))]) - out_degree = mx.nd.array([data.graph.out_degree(i) - for i in range(len(data.graph))]) - - in_feats = features.shape[1] - n_classes = data.num_labels - n_edges = data.graph.number_of_edges() - - if args.gpu <= 0: - cuda = False - ctx = mx.cpu(0) - else: - cuda = True - features = features.as_in_context(mx.gpu(0)) - labels = labels.as_in_context(mx.gpu(0)) - mask = mask.as_in_context(mx.gpu(0)) - in_degree = in_degree.as_in_context(mx.gpu(0)) - out_degree = out_degree.as_in_context(mx.gpu(0)) - ctx = mx.gpu(0) - - # create GCN model - g = DGLGraph(data.graph) - g.ndata['in_degree'] = in_degree - g.ndata['out_degree'] = out_degree - - model = GCN(g, - in_feats, - args.n_hidden, - n_classes, - args.n_layers, - 'relu', - args.dropout, - args.normalization, - ) - model.initialize(ctx=ctx) - loss_fcn = gluon.loss.SoftmaxCELoss() - - # use optimizer - trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr}) - - # initialize graph - dur = [] - for epoch in range(args.n_epochs): - if epoch >= 3: - t0 = time.time() - # forward - with mx.autograd.record(): - pred = model(features) - loss = loss_fcn(pred, labels, mask) - - #optimizer.zero_grad() - loss.backward() - trainer.step(features.shape[0]) - - if epoch >= 3: - dur.append(time.time() - t0) - print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format( - epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000)) - - # test set accuracy - pred = model(features) - accuracy = (pred*100).softmax().pick(labels).mean() - print("Final accuracy {:.2%}".format(accuracy.mean().asscalar())) - return accuracy.mean().asscalar() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='GCN') - register_data_args(parser) - parser.add_argument("--dropout", type=float, default=0.5, - help="dropout probability") - parser.add_argument("--gpu", type=int, default=-1, - help="gpu") - parser.add_argument("--lr", type=float, default=1e-3, - help="learning rate") - parser.add_argument("--n-epochs", type=int, default=20, - help="number of training epochs") - parser.add_argument("--n-hidden", type=int, default=16, - help="number of hidden gcn units") - parser.add_argument("--n-layers", type=int, default=2, - help="number of hidden gcn layers") - parser.add_argument("--normalization", - choices=['sym','left'], default=None, - help="graph normalization types (default=None)") - parser.add_argument("--self-loop", action='store_true', - help="graph self-loop (default=False)") - args = parser.parse_args() - - print(args) - - main(args) diff --git a/examples/mxnet/gcn/gcn_concat.py b/examples/mxnet/gcn/gcn_concat.py new file mode 100644 index 000000000000..8ee88501a9d5 --- /dev/null +++ b/examples/mxnet/gcn/gcn_concat.py @@ -0,0 +1,195 @@ +""" +Semi-Supervised Classification with Graph Convolutional Networks +Paper: https://arxiv.org/abs/1609.02907 +Code: https://github.com/tkipf/gcn +GCN with batch processing +""" +import argparse +import numpy as np +import time +import mxnet as mx +from mxnet import gluon +import dgl +import dgl.function as fn +from dgl import DGLGraph +from dgl.data import register_data_args, load_data + + +class GCNLayer(gluon.Block): + def __init__(self, + g, + out_feats, + activation, + dropout): + super(GCNLayer, self).__init__() + self.g = g + self.dense = gluon.nn.Dense(out_feats, activation) + self.dropout = dropout + + def forward(self, h): + self.g.ndata['h'] = h * self.g.ndata['out_norm'] + self.g.update_all(fn.copy_src(src='h', out='m'), + fn.sum(msg='m', out='accum')) + accum = self.g.ndata.pop('accum') + accum = self.dense(accum * self.g.ndata['in_norm']) + if self.dropout: + accum = mx.nd.Dropout(accum, p=self.dropout) + h = self.g.ndata.pop('h') + h = mx.nd.concat(h / self.g.ndata['out_norm'], accum, dim=1) + return h + + +class GCN(gluon.Block): + def __init__(self, + g, + n_hidden, + n_classes, + n_layers, + activation, + dropout): + super(GCN, self).__init__() + self.inp_layer = gluon.nn.Dense(n_hidden, activation) + self.dropout = dropout + self.layers = gluon.nn.Sequential() + for i in range(n_layers): + self.layers.add(GCNLayer(g, n_hidden, activation, dropout)) + self.out_layer = gluon.nn.Dense(n_classes) + + + def forward(self, features): + emb_inp = [features, self.inp_layer(features)] + if self.dropout: + emb_inp[-1] = mx.nd.Dropout(emb_inp[-1], p=self.dropout) + h = mx.nd.concat(*emb_inp, dim=1) + for layer in self.layers: + h = layer(h) + h = self.out_layer(h) + return h + + +def evaluate(model, features, labels, mask): + pred = model(features).argmax(axis=1) + accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() + return accuracy.asscalar() + + +def main(args): + # load and preprocess dataset + data = load_data(args) + + if args.self_loop: + data.graph.add_edges_from([(i,i) for i in range(len(data.graph))]) + + features = mx.nd.array(data.features) + labels = mx.nd.array(data.labels) + train_mask = mx.nd.array(data.train_mask) + val_mask = mx.nd.array(data.val_mask) + test_mask = mx.nd.array(data.test_mask) + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().asscalar(), + val_mask.sum().asscalar(), + test_mask.sum().asscalar())) + + if args.gpu < 0: + cuda = False + ctx = mx.cpu(0) + else: + cuda = True + ctx = mx.gpu(args.gpu) + + features = features.as_in_context(ctx) + labels = labels.as_in_context(ctx) + train_mask = train_mask.as_in_context(ctx) + val_mask = val_mask.as_in_context(ctx) + test_mask = test_mask.as_in_context(ctx) + + # create GCN model + g = DGLGraph(data.graph) + # normalization + in_degs = g.in_degrees().astype('float32') + out_degs = g.out_degrees().astype('float32') + in_norm = mx.nd.power(in_degs, -0.5) + out_norm = mx.nd.power(out_degs, -0.5) + if cuda: + in_norm = in_norm.as_in_context(ctx) + out_norm = out_norm.as_in_context(ctx) + g.ndata['in_norm'] = mx.nd.expand_dims(in_norm, 1) + g.ndata['out_norm'] = mx.nd.expand_dims(out_norm, 1) + + model = GCN(g, + args.n_hidden, + n_classes, + args.n_layers, + 'relu', + args.dropout, + ) + model.initialize(ctx=ctx) + n_train_samples = train_mask.sum().asscalar() + loss_fcn = gluon.loss.SoftmaxCELoss() + + # use optimizer + print(model.collect_params()) + trainer = gluon.Trainer(model.collect_params(), 'adam', + {'learning_rate': args.lr, 'wd': args.weight_decay}) + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + if epoch >= 3: + t0 = time.time() + # forward + with mx.autograd.record(): + pred = model(features) + loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1)) + loss = loss.sum() / n_train_samples + + loss.backward() + trainer.step(batch_size=1) + + if epoch >= 3: + dur.append(time.time() - t0) + acc = evaluate(model, features, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". format( + epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) + + # test set accuracy + acc = evaluate(model, features, labels, test_mask) + print("Test accuracy {:.2%}".format(acc)) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='GCN') + register_data_args(parser) + parser.add_argument("--dropout", type=float, default=0.5, + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--n-hidden", type=int, default=16, + help="number of hidden gcn units") + parser.add_argument("--n-layers", type=int, default=1, + help="number of hidden gcn layers") + parser.add_argument("--normalization", + choices=['sym','left'], default=None, + help="graph normalization types (default=None)") + parser.add_argument("--self-loop", action='store_true', + help="graph self-loop (default=False)") + parser.add_argument("--weight-decay", type=float, default=5e-4, + help="Weight for L2 loss") + args = parser.parse_args() + + print(args) + + main(args) diff --git a/examples/pytorch/gcn/gcn.py b/examples/pytorch/gcn/gcn.py index 542ac6860bea..d31ec2454f9d 100644 --- a/examples/pytorch/gcn/gcn.py +++ b/examples/pytorch/gcn/gcn.py @@ -2,78 +2,105 @@ Semi-Supervised Classification with Graph Convolutional Networks Paper: https://arxiv.org/abs/1609.02907 Code: https://github.com/tkipf/gcn - -GCN with batch processing +GCN with SPMV specialization. """ -import argparse +import argparse, time, math import numpy as np -import time import torch import torch.nn as nn import torch.nn.functional as F from dgl import DGLGraph from dgl.data import register_data_args, load_data -def gcn_msg(edges): - return {'m' : edges.src['h']} -def gcn_reduce(nodes): - return {'h' : torch.sum(nodes.mailbox['m'], 1)} +def gcn_msg(edge): + msg = edge.src['h'] * edge.src['norm'] + return {'m': msg} + + +def gcn_reduce(node): + accum = torch.sum(node.mailbox['m'], 1) * node.data['norm'] + return {'h': accum} + class NodeApplyModule(nn.Module): - def __init__(self, in_feats, out_feats, activation=None): + def __init__(self, out_feats, activation=None, bias=True): super(NodeApplyModule, self).__init__() - self.linear = nn.Linear(in_feats, out_feats) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_feats)) + else: + self.bias = None self.activation = activation + self.reset_parameters() + + def reset_parameters(self): + if self.bias is not None: + stdv = 1. / math.sqrt(self.bias.size(0)) + self.bias.data.uniform_(-stdv, stdv) def forward(self, nodes): - # normalization by square root of dst degree - h = nodes.data['h'] * nodes.data['norm'] - h = self.linear(h) + h = nodes.data['h'] + if self.bias is not None: + h = h + self.bias if self.activation: h = self.activation(h) - return {'h' : h} + return {'h': h} -class GCN(nn.Module): + +class GCNLayer(nn.Module): def __init__(self, g, in_feats, - n_hidden, - n_classes, - n_layers, + out_feats, activation, - dropout): - super(GCN, self).__init__() + dropout, + bias=True): + super(GCNLayer, self).__init__() self.g = g - + self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats)) if dropout: self.dropout = nn.Dropout(p=dropout) else: self.dropout = 0. + self.node_update = NodeApplyModule(out_feats, activation, bias) + self.reset_parameters() - self.layers = nn.ModuleList() + def reset_parameters(self): + stdv = 1. / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) - # input layer - self.layers.append(NodeApplyModule(in_feats, n_hidden, activation)) + def forward(self, h): + if self.dropout: + h = self.dropout(h) + self.g.ndata['h'] = torch.mm(h, self.weight) + self.g.update_all(gcn_msg, gcn_reduce, self.node_update) + h = self.g.ndata.pop('h') + return h +class GCN(nn.Module): + def __init__(self, + g, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout): + super(GCN, self).__init__() + self.layers = nn.ModuleList() + # input layer + self.layers.append(GCNLayer(g, in_feats, n_hidden, activation, dropout)) # hidden layers for i in range(n_layers - 1): - self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation)) - + self.layers.append(GCNLayer(g, n_hidden, n_hidden, activation, dropout)) # output layer - self.layers.append(NodeApplyModule(n_hidden, n_classes)) + self.layers.append(GCNLayer(g, n_hidden, n_classes, None, dropout)) def forward(self, features): - self.g.ndata['h'] = features - - for idx, layer in enumerate(self.layers): - # apply dropout - if idx > 0 and self.dropout: - self.g.ndata['h'] = self.dropout(self.g.ndata['h']) - # normalization by square root of src degree - self.g.ndata['h'] = self.g.ndata['h'] * self.g.ndata['norm'] - self.g.update_all(gcn_msg, gcn_reduce, layer) - return self.g.ndata.pop('h') + h = features + for layer in self.layers: + h = layer(h) + return h def evaluate(model, features, labels, mask): model.eval() @@ -88,7 +115,6 @@ def evaluate(model, features, labels, mask): def main(args): # load and preprocess dataset data = load_data(args) - features = torch.FloatTensor(data.features) labels = torch.LongTensor(data.labels) train_mask = torch.ByteTensor(data.train_mask) @@ -97,6 +123,16 @@ def main(args): in_feats = features.shape[1] n_classes = data.num_labels n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().item(), + val_mask.sum().item(), + test_mask.sum().item())) if args.gpu < 0: cuda = False @@ -133,6 +169,7 @@ def main(args): if cuda: model.cuda() + loss_fcn = torch.nn.CrossEntropyLoss() # use optimizer optimizer = torch.optim.Adam(model.parameters(), @@ -147,8 +184,7 @@ def main(args): t0 = time.time() # forward logits = model(features) - logp = F.log_softmax(logits, 1) - loss = F.nll_loss(logp[train_mask], labels[train_mask]) + loss = loss_fcn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() @@ -159,8 +195,8 @@ def main(args): acc = evaluate(model, features, labels, val_mask) print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " - "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(), - acc, n_edges / np.mean(dur) / 1000)) + "ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(), + acc, n_edges / np.mean(dur) / 1000)) print() acc = evaluate(model, features, labels, test_mask)