# train.py
# forked from Sohl-Dickstein/Diffusion-Probabilistic-Models
import argparse
import numpy as np
import os
import warnings

import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from blocks.algorithms import (RMSProp, GradientDescent, CompositeRule,
                               RemoveNotFinite)
from blocks.extensions import FinishAfter, Timing, Printing
from blocks.extensions.monitoring import (TrainingDataMonitoring,
                                          DataStreamMonitoring)
from blocks.extensions.saveload import Checkpoint
from blocks.extensions.training import SharedVariableModifier
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_dropout
from blocks.main_loop import MainLoop
import blocks.model
from blocks.roles import INPUT, PARAMETER
from fuel.streams import DataStream
from fuel.schemes import ShuffledScheme
from fuel.transformers import Flatten, ScaleAndShift

import extensions
import model
import util
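# extensions, model, and util are local modules in this repository:
# training extensions (learning rate decay, log likelihood and plotting),
# the DiffusionModel definition, and logging helpers respectively.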

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', default=512, type=int,
                        help='Batch size')
    parser.add_argument('--lr', default=1e-3, type=float,
                        help='Initial learning rate. ' + \
                        'Will be decayed until it\'s 1e-5.')
    parser.add_argument('--resume_file', default=None, type=str,
                        help='Name of saved model to continue training')
    parser.add_argument('--suffix', default='', type=str,
                        help='Optional descriptive suffix for model')
    parser.add_argument('--output-dir', type=str, default='./',
                        help='Output directory to store trained models')
    parser.add_argument('--ext-every-n', type=int, default=25,
                        help='Evaluate training extensions every N epochs')
    parser.add_argument('--model-args', type=str, default='',
                        help='Dictionary string to be eval()d containing model arguments.')
    parser.add_argument('--dropout_rate', type=float, default=0.,
                        help='Rate to use for dropout during training+testing.')
    parser.add_argument('--dataset', type=str, default='MNIST',
                        help='Name of dataset to use.')
    parser.add_argument('--plot_before_training', action='store_true',
                        default=False,
                        help='Save diagnostic plots at epoch 0, before any training.')
    args = parser.parse_args()

    model_args = eval('dict(' + args.model_args + ')')
    print model_args

    if not os.path.exists(args.output_dir):
        raise IOError("Output directory '%s' does not exist."%args.output_dir)

    return args, model_args

if __name__ == '__main__':
    # TODO batches_per_epoch should not be hard coded
    batches_per_epoch = 500

    import sys
    sys.setrecursionlimit(10000000)

    args, model_args = parse_args()

    if args.resume_file is not None:
        print "Resuming training from " + args.resume_file
        from blocks.scripts import continue_training
        continue_training(args.resume_file)
    ## load the training data
    if args.dataset == 'MNIST':
        from fuel.datasets import MNIST
        dataset_train = MNIST(['train'], sources=('features',))
        dataset_test = MNIST(['test'], sources=('features',))
        n_colors = 1
        spatial_width = 28
    elif args.dataset == 'CIFAR10':
        from fuel.datasets import CIFAR10
        dataset_train = CIFAR10(['train'], sources=('features',))
        dataset_test = CIFAR10(['test'], sources=('features',))
        n_colors = 3
        spatial_width = 32
    elif args.dataset == 'IMAGENET':
        from imagenet_data import IMAGENET
        spatial_width = 128
        dataset_train = IMAGENET(['train'], width=spatial_width)
        dataset_test = IMAGENET(['test'], width=spatial_width)
        n_colors = 3
    else:
        raise ValueError("Unknown dataset %s."%args.dataset)
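    # Flatten collapses each (channels, height, width) image into a single
    # feature vector, since the model cost below is defined on a 2D
    # (batch, features) matrix.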
    train_stream = Flatten(DataStream.default_stream(dataset_train,
                           iteration_scheme=ShuffledScheme(
                               examples=dataset_train.num_examples,
                               batch_size=args.batch_size)))
    test_stream = Flatten(DataStream.default_stream(dataset_test,
                          iteration_scheme=ShuffledScheme(
                              examples=dataset_test.num_examples,
                              batch_size=args.batch_size)))
    shp = next(train_stream.get_epoch_iterator())[0].shape

    # make the training data 0 mean and variance 1
    # TODO compute mean and variance on full dataset, not minibatch
    Xbatch = next(train_stream.get_epoch_iterator())[0]
    scl = 1./np.sqrt(np.mean((Xbatch-np.mean(Xbatch))**2))
    shft = -np.mean(Xbatch*scl)
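    # With scl = 1/std(Xbatch) and shft = -scl*mean(Xbatch), the transform
    # x*scl + shft yields approximately zero mean and unit variance,
    # estimated from this single minibatch.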
    # scale is applied before shift
    train_stream = ScaleAndShift(train_stream, scl, shft)
    test_stream = ScaleAndShift(test_stream, scl, shft)

    baseline_uniform_noise = 1./255. # appropriate for MNIST and CIFAR10 Fuel datasets, which are scaled [0,1]
    uniform_noise = baseline_uniform_noise/scl
    ## initialize the model
    dpm = model.DiffusionModel(spatial_width, n_colors, uniform_noise=uniform_noise, **model_args)
    dpm.initialize()
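    # spatial_width and n_colors come from the selected dataset; any extra
    # keyword arguments supplied through --model-args are forwarded to the
    # DiffusionModel constructor.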
    ## set up optimization
    features = T.matrix('features', dtype=theano.config.floatX)
    cost = dpm.cost(features)
    blocks_model = blocks.model.Model(cost)
    cg_nodropout = ComputationGraph(cost)
    if args.dropout_rate > 0:
        # DEBUG this triggers an error on my machine
        # apply dropout to all the input variables
        inputs = VariableFilter(roles=[INPUT])(cg_nodropout.variables)
        # dropconnect
        # inputs = VariableFilter(roles=[PARAMETER])(cg_nodropout.variables)
        cg = apply_dropout(cg_nodropout, inputs, args.dropout_rate)
    else:
        cg = cg_nodropout
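    # apply_dropout rewrites the computation graph so the selected variables
    # are replaced by dropped-out copies; filtering on PARAMETER roles
    # instead of INPUT roles (commented out above) would give dropconnect.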
    step_compute = RMSProp(learning_rate=args.lr, max_scaling=1e10)
    algorithm = GradientDescent(step_rule=CompositeRule([RemoveNotFinite(),
                                                         step_compute]),
                                parameters=cg.parameters, cost=cost)
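    # The composite step rule first passes gradients through RemoveNotFinite,
    # which guards against NaN/Inf entries, and then applies the RMSProp
    # update.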
    extension_list = []
    extension_list.append(
        SharedVariableModifier(step_compute.learning_rate,
                               extensions.decay_learning_rate,
                               after_batch=False,
                               every_n_batches=batches_per_epoch))
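    # SharedVariableModifier calls extensions.decay_learning_rate once every
    # batches_per_epoch batches to decay the RMSProp learning rate.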
    extension_list.append(FinishAfter(after_n_epochs=100001))

    ## logging of test set performance
    extension_list.append(extensions.LogLikelihood(dpm, test_stream, scl,
        every_n_batches=args.ext_every_n*batches_per_epoch, before_training=False))
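    # Evaluates the model log likelihood on the test stream every ext_every_n
    # nominal epochs; scl is passed presumably so the extension can account
    # for the data rescaling applied above.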
    ## set up logging
    extension_list.extend([Timing(), Printing()])

    model_dir = util.create_log_dir(args, dpm.name + '_' + args.dataset)
    model_save_name = os.path.join(model_dir, 'model.pkl')
    extension_list.append(
        Checkpoint(model_save_name, every_n_batches=args.ext_every_n*batches_per_epoch,
                   save_separately=['log']))
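    # Checkpoint pickles the whole main loop to model.pkl every ext_every_n
    # nominal epochs, with the training log written to a separate file.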
    # generate plots
    extension_list.append(extensions.PlotMonitors(model_dir,
        every_n_batches=args.ext_every_n*batches_per_epoch, before_training=args.plot_before_training))
    test_batch = next(test_stream.get_epoch_iterator())[0]
    extension_list.append(extensions.PlotSamples(dpm, algorithm, test_batch, model_dir,
        every_n_batches=args.ext_every_n*batches_per_epoch, before_training=args.plot_before_training))
    internal_state = dpm.internal_state(features)
    train_batch = next(train_stream.get_epoch_iterator())[0]
    # extension_list.append(
    #     extensions.PlotInternalState(dpm, blocks_model, internal_state, features, train_batch, model_dir,
    #         every_n_batches=args.ext_every_n*batches_per_epoch, before_training=args.plot_before_training))
    extension_list.append(
        extensions.PlotParameters(dpm, blocks_model, model_dir,
            every_n_batches=args.ext_every_n*batches_per_epoch, before_training=args.plot_before_training))
    # extension_list.append(
    #     extensions.PlotGradients(dpm, blocks_model, algorithm, train_batch, model_dir,
    #         every_n_batches=args.ext_every_n*batches_per_epoch, before_training=args.plot_before_training))
    # # console monitors
    # # DEBUG -- incorporating train_monitor or test_monitor triggers a large number of
    # # float64 vs float32 GPU warnings, although monitoring still works. I think this is a Blocks
    # # bug. Uncomment this code to have more information during debugging/development.
    # train_monitor_vars = [cost]
    # norms, grad_norms = util.get_norms(blocks_model, algorithm.gradients)
    # train_monitor_vars.extend(norms + grad_norms)
    # train_monitor = TrainingDataMonitoring(
    #     train_monitor_vars, prefix='train', after_batch=True, before_training=True)
    # extension_list.append(train_monitor)
    # test_monitor_vars = [cost]
    # test_monitor = DataStreamMonitoring(test_monitor_vars, test_stream, prefix='test', before_training=True)
    # extension_list.append(test_monitor)
    ## train
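    # The recursion limit is raised again here: pickling the main loop's deep
    # Theano graph (e.g. by the Checkpoint extension) can exceed Python's
    # default recursion limit.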
    sys.setrecursionlimit(10000000)
    main_loop = MainLoop(model=blocks_model, algorithm=algorithm,
                         data_stream=train_stream,
                         extensions=extension_list)
    main_loop.run()