train_translator.py (forked from wengong-jin/hgraph2graph)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
import math, random, sys
import numpy as np
import argparse
from hgraph import *
import rdkit
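
# Silence RDKit warnings: only CRITICAL messages will be logged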
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
parser = argparse.ArgumentParser()
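
# Data, vocabulary, checkpointing, and conditional-model options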
parser.add_argument('--train', required=True)
parser.add_argument('--vocab', required=True)
parser.add_argument('--atom_vocab', default=common_atom_vocab)
parser.add_argument('--save_dir', required=True)
parser.add_argument('--load_epoch', type=int, default=-1)
parser.add_argument('--conditional', action='store_true')
parser.add_argument('--cond_size', type=int, default=4)
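
# Model architecture hyperparameters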
parser.add_argument('--rnn_type', type=str, default='LSTM')
parser.add_argument('--hidden_size', type=int, default=270)
parser.add_argument('--embed_size', type=int, default=270)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--latent_size', type=int, default=4)
parser.add_argument('--depthT', type=int, default=20)
parser.add_argument('--depthG', type=int, default=20)
parser.add_argument('--diterT', type=int, default=1)
parser.add_argument('--diterG', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.0)
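
# Optimization and training-schedule hyperparameters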
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--clip_norm', type=float, default=20.0)
parser.add_argument('--beta', type=float, default=0.3)
parser.add_argument('--epoch', type=int, default=12)
parser.add_argument('--anneal_rate', type=float, default=0.9)
parser.add_argument('--print_iter', type=int, default=50)
parser.add_argument('--save_iter', type=int, default=-1)
args = parser.parse_args()
print(args)
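
# Load the vocabulary file (one whitespace-separated pair per line) and wrap it in a PairVocab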
vocab = [x.strip("\r\n ").split() for x in open(args.vocab)]
args.vocab = PairVocab(vocab)
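
# Build the hierarchical graph translation model and move it to the GPU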
if args.conditional:
    model = HierCondVGNN(args).cuda()
else:
    model = HierVGNN(args).cuda()
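
# Initialize parameters: zeros for 1-D tensors (biases), Xavier normal for weight matrices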
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)
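
# Optionally resume from a saved checkpoint in save_dir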
if args.load_epoch >= 0:
    model.load_state_dict(torch.load(args.save_dir + "/model." + str(args.load_epoch)))
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate)
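
# Helpers reporting the overall L2 norm of the parameters and of their gradients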
param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))
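
# total_step counts minibatches; meters accumulates the quantities printed every --print_iter steps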
total_step = 0
beta = args.beta
meters = np.zeros(6)
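
# Training loop: each epoch streams batches from the --train folder via DataFolder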
for epoch in range(args.load_epoch + 1, args.epoch):
    dataset = DataFolder(args.train, args.batch_size)

    for batch in dataset:
        total_step += 1
        batch = batch + (beta,)  # append the KL weight beta as the final model argument
        model.zero_grad()
        loss, kl_div, wacc, iacc, tacc, sacc = model(*batch)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
        optimizer.step()

        # Accumulate KL, loss, and the four decoder accuracies (scaled to percentages)
        meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100])

        if total_step % args.print_iter == 0:
            meters /= args.print_iter
            print("[%d] Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model)))
            sys.stdout.flush()
            meters *= 0

        # Iteration-based checkpointing and LR annealing (enabled when --save_iter >= 0)
        if args.save_iter >= 0 and total_step % args.save_iter == 0:
            n_iter = total_step // args.save_iter - 1
            torch.save(model.state_dict(), args.save_dir + "/model." + str(n_iter))
            scheduler.step()
            print("learning rate: %.6f" % scheduler.get_lr()[0])

    del dataset
    # Epoch-based checkpointing and LR annealing (the default, --save_iter == -1)
    if args.save_iter == -1:
        torch.save(model.state_dict(), args.save_dir + "/model." + str(epoch))
        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_lr()[0])