"""
run-gqn.py
Script to train the a GQN on the Shepard-Metzler dataset
in accordance to the hyperparameter settings described in
the supplementary materials of the paper.
"""
import os
import random
import math
from argparse import ArgumentParser
# Torch
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
# TensorboardX
from tensorboardX import SummaryWriter
# Ignite
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
from building_blocks.gqn import GenerativeQueryNetwork
from building_blocks.annealer import Annealer
from building_blocks.training import partition
from data.datasets import ShepardMetzler, Scene
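
# Use GPU if available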
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")
# Random seeding
random.seed(99)
torch.manual_seed(99)
if cuda: torch.cuda.manual_seed(99)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if __name__ == '__main__':
    parser = ArgumentParser(description='Generative Query Network on Shepard Metzler Example')
    parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs to run (default: 200)')
    parser.add_argument('--batch_size', type=int, default=36, help='batch size (default: 36)')
    parser.add_argument('--data_dir', type=str, help='location of data', default="data/shepard_metzler_5_parts-torch")
    parser.add_argument('--log_dir', type=str, help='location of logging', default="log")
    parser.add_argument('--checkpoint_dir', type=str, help='location of checkpoints', default="checkpoints")
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    # Note: argparse's type=bool treats any non-empty string as True, so omit the
    # flag entirely to keep the default of False.
    parser.add_argument('--data_parallel', type=bool, help='whether to parallelise based on data (default: False)', default=False)
    parser.add_argument('--max_n', type=int, help="maximum number of examples in dataset (default: -1, i.e. all)", default=-1)
    parser.add_argument('--eval_n', type=int, help="evaluate every n iterations", default=1000)
    parser.add_argument('--resume', type=str, help="iteration to resume at. Helps locate the file in args.checkpoint_dir", default="")
    args = parser.parse_args()

    # Create model and optimizer
    model = GenerativeQueryNetwork(x_dim=3, v_dim=7, r_dim=256, h_dim=128, z_dim=64, L=8).to(device)
    model = nn.DataParallel(model) if args.data_parallel else model
    # TODO: log the throughput improvements if we use --data_parallel=True
    optimizer = torch.optim.Adam(model.parameters(), lr=5 * 10 ** (-5))

    # Rate annealing schemes
    sigma_scheme = Annealer(2.0, 0.7, 2e5)
    mu_scheme = Annealer(5 * 10 ** (-4), 5 * 10 ** (-5), 1.6e6)

    # Load the dataset
    kwargs = {'max_n': args.max_n}
    train_dataset = ShepardMetzler(root_dir=args.data_dir, **kwargs)
    valid_dataset = ShepardMetzler(root_dir=args.data_dir, train=False, **kwargs)

    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    def step(engine, batch):
        # One training step: maximise the ELBO (reconstruction likelihood minus
        # KL divergence), annealing the pixel variance (sigma) and learning
        # rate (mu) over the course of training.
        model.train()

        x, v = batch
        x, v = x.to(device), v.to(device)
        x, v, x_q, v_q = partition(x, v)

        # Reconstruction, representation and divergence
        x_mu, _, kl = model(x, v, x_q, v_q)

        # Log likelihood
        sigma = next(sigma_scheme)
        ll = Normal(x_mu, sigma).log_prob(x_q)

        likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3]))
        kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3]))

        # Evidence lower bound
        elbo = likelihood - kl_divergence
        loss = -elbo
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            # Anneal learning rate, scaled by Adam's bias-correction factor
            # sqrt(1 - beta2^i) / (1 - beta1^i) with the default betas (0.9, 0.999)
            mu = next(mu_scheme)
            i = engine.state.iteration
            for group in optimizer.param_groups:
                group["lr"] = mu * math.sqrt(1 - 0.999 ** i) / (1 - 0.9 ** i)

        return {"elbo": elbo.item(), "kl": kl_divergence.item(), "sigma": sigma, "mu": mu}
    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ["elbo", "kl", "sigma", "mu"]
    RunningAverage(output_transform=lambda x: x["elbo"]).attach(trainer, "elbo")
    RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
    RunningAverage(output_transform=lambda x: x["sigma"]).attach(trainer, "sigma")
    RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
    ProgressBar().attach(trainer, metric_names=metric_names)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint(args.checkpoint_dir, "checkpoint", save_interval=args.eval_n,
                                         n_saved=2, require_empty=False, save_as_state_dict=False)
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=checkpoint_handler,
                              to_save={'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
                                       'mu': mu_scheme, 'sigma': sigma_scheme})

    archiver = ModelCheckpoint(args.checkpoint_dir, "archive", save_interval=args.eval_n * 10, n_saved=20,
                               require_empty=False, save_as_state_dict=False)
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=archiver,
                              to_save={'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
                                       'mu': mu_scheme, 'sigma': sigma_scheme})

    timer = Timer(average=True).attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                                       pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # Tensorboard writer
    writer = SummaryWriter(logdir=args.log_dir)
    def savedIn(filename):
        # Load a checkpoint file from args.checkpoint_dir if it exists,
        # otherwise return None.
        name = os.path.join(args.checkpoint_dir, filename)
        if os.path.isfile(name):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(name)
            return checkpoint

    def updateAnnealers(steps):
        # Fast-forward the rate annealing schemes when resuming without
        # saved schedules.
        print(f"Fast forwarding rates to {steps} steps")
        for _ in range(steps):
            next(sigma_scheme)
            next(mu_scheme)
    def loadCheckpoint():
        global mu_scheme, sigma_scheme
        if args.resume == "":
            return

        files = {x: "checkpoint_{}_{}.pth".format(x, args.resume)
                 for x in ["model", "optimizer", "mu", "sigma"]}
        loaded = {x: savedIn(file) for x, file in files.items()}

        mu = sigma = None
        for name, k in loaded.items():
            if not k:
                print("Unable to load {}".format(files[name]))
                continue
            if name == "model":
                model.load_state_dict(k)
            elif name == "optimizer":
                optimizer.load_state_dict(k)
            elif name == "mu":
                mu = k
            elif name == "sigma":
                sigma = k

        # If we don't have a saved model here, don't do anything else
        if not loaded['model']:
            return

        if mu is None or sigma is None:
            updateAnnealers(int(args.resume))
        else:
            mu_scheme = mu
            sigma_scheme = sigma
        print("=> LOADING CHECKPOINT_{} SUCCESS!".format(args.resume))
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_metrics(engine):
        for key, value in engine.state.metrics.items():
            writer.add_scalar("training/{}".format(key), value, engine.state.iteration)
    @trainer.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        if engine.state.iteration % args.eval_n == 0:
            with torch.no_grad():
                x, v = engine.state.batch
                x, v = x.to(device), v.to(device)
                x, v, x_q, v_q = partition(x, v)

                x_mu, r, _ = model(x, v, x_q, v_q)
                r = r.view(-1, 1, 16, 16)

                # Send to CPU
                x_mu = x_mu.detach().cpu().float()
                r = r.detach().cpu().float()

                writer.add_image("representation", make_grid(r), engine.state.iteration)
                writer.add_image("reconstruction", make_grid(x_mu), engine.state.iteration)
    @trainer.on(Events.ITERATION_COMPLETED)
    def validate(engine):
        if engine.state.iteration % args.eval_n == 0:
            model.eval()
            with torch.no_grad():
                x, v = next(iter(valid_loader))
                x, v = x.to(device), v.to(device)
                x, v, x_q, v_q = partition(x, v)

                # Reconstruction, representation and divergence
                x_mu, _, kl = model(x, v, x_q, v_q)

                # Validate at last sigma
                ll = Normal(x_mu, sigma_scheme.recent).log_prob(x_q)

                likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3]))
                kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3]))

                # Evidence lower bound
                elbo = likelihood - kl_divergence

                writer.add_scalar("validation/elbo", elbo.item(), engine.state.iteration)
                writer.add_scalar("validation/kl", kl_divergence.item(), engine.state.iteration)
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            import warnings
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    loadCheckpoint()
    trainer.run(train_loader, args.n_epochs)
    writer.close()