[Model] Fix GCN Normalization (dmlc#249)
* WIP

* lr -> 0.01

* new cora dataset

* normalization code

* minor format change

* normalization factor for deg bucket
lingfanyu authored and jermainewang committed Dec 5, 2018
1 parent 21255b6 commit 378c264
Showing 3 changed files with 183 additions and 40 deletions.
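For context, the normalization added in this commit is, as I read the diff below (the commit message does not spell it out), the symmetric normalization of the standard GCN propagation rule. With self-loops added to the graph, each node's features are scaled by deg^{-1/2} once on the source side (before aggregation) and once on the destination side (after aggregation), which together amount to

    H^{(l+1)} = \sigma\big( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \big), \quad \tilde{A} = A + I, \quad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}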
76 changes: 59 additions & 17 deletions examples/pytorch/gcn/gcn.py
@@ -27,7 +27,9 @@ def __init__(self, in_feats, out_feats, activation=None):
self.activation = activation

def forward(self, nodes):
h = self.linear(nodes.data['h'])
# normalization by square root of dst degree
h = nodes.data['h'] * nodes.data['norm']
h = self.linear(h)
if self.activation:
h = self.activation(h)
return {'h' : h}
@@ -49,8 +51,10 @@ def __init__(self,
else:
self.dropout = 0.

self.layers = nn.ModuleList()

# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
self.layers.append(NodeApplyModule(in_feats, n_hidden, activation))

# hidden layers
for i in range(n_layers - 1):
@@ -62,22 +66,34 @@ def forward(self, features):
def forward(self, features):
self.g.ndata['h'] = features

for layer in self.layers:
for idx, layer in enumerate(self.layers):
# apply dropout
if self.dropout:
self.g.apply_nodes(apply_node_func=
lambda nodes: {'h': self.dropout(nodes.data['h'])})
if idx > 0 and self.dropout:
self.g.ndata['h'] = self.dropout(self.g.ndata['h'])
# normalization by square root of src degree
self.g.ndata['h'] = self.g.ndata['h'] * self.g.ndata['norm']
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.ndata.pop('h')

def evaluate(model, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)

def main(args):
# load and preprocess dataset
# Todo: adjacency normalization
data = load_data(args)

features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
@@ -89,10 +105,24 @@ def main(args):
torch.cuda.set_device(args.gpu)
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
train_mask = train_mask.cuda()
val_mask = val_mask.cuda()
test_mask = test_mask.cuda()

# create GCN model
# graph preprocess and calculate normalization factor
g = DGLGraph(data.graph)
n_edges = g.number_of_edges()
# add self loop
g.add_edges(g.nodes(), g.nodes())
# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
if cuda:
norm = norm.cuda()
g.ndata['norm'] = norm.unsqueeze(1)

# create GCN model
model = GCN(g,
in_feats,
args.n_hidden,
@@ -105,17 +135,20 @@ def main(args):
model.cuda()

# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
optimizer = torch.optim.Adam(model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)

# initialize graph
dur = []
for epoch in range(args.n_epochs):
model.train()
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])
loss = F.nll_loss(logp[train_mask], labels[train_mask])

optimizer.zero_grad()
loss.backward()
@@ -124,24 +157,33 @@ if epoch >= 3:
if epoch >= 3:
dur.append(time.time() - t0)

print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
acc = evaluate(model, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
acc, n_edges / np.mean(dur) / 1000))

print()
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0,
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=20,
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
args = parser.parse_args()
print(args)

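As a quick sanity check (not part of the commit), the effect of scaling node features by deg^{-1/2} on the source side before aggregation and again on the destination side afterwards can be compared against the dense symmetrically normalized propagation on a toy graph. The 4-node graph and feature width below are made up for illustration, and only torch is assumed:

import torch

# toy undirected graph on 4 nodes; self-loops added, as in the diff
A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 1., 0.],
                  [1., 1., 0., 1.],
                  [0., 0., 1., 0.]])
A_tilde = A + torch.eye(4)
deg = A_tilde.sum(dim=1)                    # degrees including the self-loop
norm = deg.pow(-0.5).unsqueeze(1)           # plays the role of g.ndata['norm']
H = torch.randn(4, 8)                       # random node features

# per-node scaling on both ends of each edge, as the updated gcn.py does
out_scaled = norm * (A_tilde @ (H * norm))

# dense symmetric normalization: D^{-1/2} (A + I) D^{-1/2} H
D_inv_sqrt = torch.diag(deg.pow(-0.5))
out_dense = D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H

assert torch.allclose(out_scaled, out_dense, atol=1e-6)

The assertion passing is what justifies splitting the normalization into a src-side multiply in forward() and a dst-side multiply in NodeApplyModule, rather than materializing the normalized adjacency matrix.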
80 changes: 60 additions & 20 deletions examples/pytorch/gcn/gcn_spmv.py
@@ -18,15 +18,16 @@
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__()

self.linear = nn.Linear(in_feats, out_feats)
nn.init.xavier_normal_(self.linear.weight)
self.activation = activation

def forward(self, nodes):
h = self.linear(nodes.data['h'])
# normalization by square root of dst degree
h = nodes.data['h'] * nodes.data['norm']
h = self.linear(h)
if self.activation:
h = self.activation(h)

return {'h': h}

class GCN(nn.Module):
@@ -46,8 +47,10 @@ def __init__(self,
else:
self.dropout = 0.

self.layers = nn.ModuleList()

# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
self.layers.append(NodeApplyModule(in_feats, n_hidden, activation))

# hidden layers
for i in range(n_layers - 1):
@@ -59,24 +62,35 @@ def forward(self, features):
def forward(self, features):
self.g.ndata['h'] = features

for layer in self.layers:
for idx, layer in enumerate(self.layers):
# apply dropout
if self.dropout:
self.g.apply_nodes(apply_node_func=
lambda nodes: {'h': self.dropout(nodes.data['h'])})
if idx > 0 and self.dropout:
self.g.ndata['h'] = self.dropout(self.g.ndata['h'])
# normalization by square root of src degree
self.g.ndata['h'] = self.g.ndata['h'] * self.g.ndata['norm']
self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'),
layer)
return self.g.pop_n_repr('h')

def evaluate(model, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)

def main(args):
# load and preprocess dataset
# Todo: adjacency normalization
data = load_data(args)

features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
@@ -88,10 +102,24 @@ def main(args):
torch.cuda.set_device(args.gpu)
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
train_mask = train_mask.cuda()
val_mask = val_mask.cuda()
test_mask = test_mask.cuda()

# create GCN model
# graph preprocess and calculate normalization factor
g = DGLGraph(data.graph)
n_edges = g.number_of_edges()
# add self loop
g.add_edges(g.nodes(), g.nodes())
# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
if cuda:
norm = norm.cuda()
g.ndata['norm'] = norm.unsqueeze(1)

# create GCN model
model = GCN(g,
in_feats,
args.n_hidden,
@@ -104,17 +132,20 @@ def main(args):
model.cuda()

# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
optimizer = torch.optim.Adam(model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)

# initialize graph
dur = []
for epoch in range(args.n_epochs):
model.train()
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])
loss = F.nll_loss(logp[train_mask], labels[train_mask])

optimizer.zero_grad()
loss.backward()
@@ -123,24 +154,33 @@ if epoch >= 3:
if epoch >= 3:
dur.append(time.time() - t0)

print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
acc = evaluate(model, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(),
acc, n_edges / np.mean(dur) / 1000))

print()
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0,
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=20,
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
args = parser.parse_args()
print(args)

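gcn_spmv.py receives the identical normalization change; the only difference from gcn.py is that it relies on DGL's built-in message and reduce functions (fn.copy_src / fn.sum), which DGL can lower to a sparse matrix product, instead of the user-defined gcn_msg / gcn_reduce. With the new defaults (lr=1e-2, dropout=0.5, weight-decay=5e-4, 200 epochs), either script would be run along the lines of

    python gcn.py --dataset cora

where the --dataset flag is assumed to be the one registered by register_data_args; check that helper for the exact argument name.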
