sweep_model_trainer.py
'''
This script trains a model using the Weights & Biases sweep configuration and saves it.
'''
# Imports
import os
import copy
import wandb
import torch
from sam.sam import SAM
from data_setup import Data
from cifar10_sweep_config import sweep_config
from models.classes.first_layer_unitary_net import FstLayUniNet

# Functions
#-------------------------------------------------------------------------------------#
def initalize_config_defaults(sweep_config):
    # Use the first listed value (or the minimum of a range) as the default for each parameter
    config_defaults = {}
    for key in sweep_config['parameters']:
        if list(sweep_config['parameters'][key].keys())[0] == 'values':
            config_defaults.update({key : sweep_config['parameters'][key]["values"][0]})
        else:
            config_defaults.update({key : sweep_config['parameters'][key]["min"]})

    wandb.init(config = config_defaults)
    config = wandb.config

    return config
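
# Note: the default-building loop above assumes each entry of sweep_config['parameters']
# is either a discrete list, e.g. {"values": [...]}, or a range, e.g. {"min": ..., "max": ...};
# the first listed value (or the minimum) becomes the default passed to wandb.init().
# A hypothetical minimal entry, for illustration only:
#   sweep_config = {"parameters": {"learning_rate": {"values": [0.1, 0.01, 0.001]},
#                                  "momentum"     : {"min": 0.5, "max": 0.99}}}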

def initalize_net(set_name, gpu, config):
    # Network
    net = FstLayUniNet(set_name, gpu = gpu,
                       U_filename = config.transformation,
                       model_name = config.model_name,
                       pretrained = config.pretrained)
    # net.load_state_dict(torch.load('models/pretrained/CIFAR10/Ucifar10_mobilenetv2_x1_4_w_acc_78.pt', map_location=torch.device('cpu')))

    # Return network
    return net.cuda() if gpu else net

def initalize_optimizer(data, net, config):
    if config.optimizer == 'sgd':
        if config.use_SAM:
            optimizer = SAM(net.parameters(), torch.optim.SGD, lr=config.learning_rate,
                            momentum=config.momentum, weight_decay=config.weight_decay)
        else:
            optimizer = torch.optim.SGD(net.parameters(), lr=config.learning_rate, momentum=config.momentum,
                                        weight_decay=config.weight_decay)

    elif config.optimizer == 'nesterov':
        if config.use_SAM:
            optimizer = SAM(net.parameters(), torch.optim.SGD, lr=config.learning_rate, momentum=config.momentum,
                            weight_decay=config.weight_decay, nesterov=True)
        else:
            optimizer = torch.optim.SGD(net.parameters(), lr=config.learning_rate, momentum=config.momentum,
                                        weight_decay=config.weight_decay, nesterov=True)

    elif config.optimizer == 'adam':
        if config.use_SAM:
            optimizer = SAM(net.parameters(), torch.optim.Adam, lr=config.learning_rate,
                            weight_decay=config.weight_decay)
        else:
            optimizer = torch.optim.Adam(net.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    elif config.optimizer == 'adadelta':
        if config.use_SAM:
            optimizer = SAM(net.parameters(), torch.optim.Adadelta, **{"lr"           : config.learning_rate,
                                                                       "weight_decay" : config.weight_decay,
                                                                       "rho"          : config.momentum})
        else:
            optimizer = torch.optim.Adadelta(net.parameters(), lr=config.learning_rate,
                                             weight_decay=config.weight_decay, rho=config.momentum)

    return optimizer
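
# Note: SAM here is assumed to follow the common two-step (ascent then descent) interface,
# wrapping the base optimizer class passed as its second argument. Under that assumption,
# optimizer.step(closure) in train() first perturbs the weights, then calls closure() to
# recompute the loss and gradients at the perturbed point before the actual descent step:
#   optimizer = SAM(net.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)
#   loss = criterion(net(inputs), targets)
#   loss.backward()
#   optimizer.step(closure)   # closure re-runs forward/backward at the perturbed weights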

def initalize_scheduler(optimizer, data, config):
    if config.scheduler == "Cosine Annealing":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = int(config.epochs))

    return scheduler
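
# Note: with T_max set to the number of epochs, the cosine annealing schedule is assumed to
# be stepped once per epoch (scheduler.step() at the end of each epoch in train()), so the
# learning rate traces a single half-cosine from its initial value over the full run.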

def initalize_criterion(config):
    # Setup Criterion
    if config.criterion == "mse":
        criterion = torch.nn.MSELoss()
    elif config.criterion == "cross_entropy":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        print("Invalid criterion setting in sweep_config.py")
        exit()

    return criterion
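
# Note: torch.nn.CrossEntropyLoss expects integer class indices, while torch.nn.MSELoss is
# used here with one-hot targets; this is why train() and test() keep both orginal_labels
# (class indices) and the one-hot encoded labels, and pick between them based on
# config.criterion.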

def train(data, save_model):
    # Weights and Biases Setup
    config = initalize_config_defaults(sweep_config)

    # Get training data
    train_loader = data.get_train_loader(config.batch_size)
    wandb.log({"Data Augmentation" : data.data_augment})

    # Initialize Network
    net = initalize_net(data.set_name, data.gpu, config)
    net.train(True)

    # Setup Optimizer and Criterion
    optimizer = initalize_optimizer(data, net, config)
    criterion = initalize_criterion(config)
    if config.scheduler is not None:
        scheduler = initalize_scheduler(optimizer, data, config)

    # Loop over epochs
    correct = 0
    total_tested = 0
    for epoch in range(int(config.epochs)):
        epoch_loss = 0
        for i, batch_data in enumerate(train_loader, 0):
            # Get labels and inputs from train_loader
            inputs, labels = batch_data

            # One hot the labels
            orginal_labels = copy.copy(labels).long()
            labels = torch.eq(labels.view(labels.size(0), 1), torch.arange(10).reshape(1, 10).repeat(labels.size(0), 1)).float()

            # Push to gpu
            if data.gpu:
                orginal_labels = orginal_labels.cuda()
                inputs, labels = inputs.cuda(), labels.cuda()

            # Set the parameter gradients to zero
            optimizer.zero_grad()

            # Forward pass
            with torch.set_grad_enabled(True):
                # SAM optimizer needs a closure function so it can
                # reevaluate the loss at the perturbed weights
                def closure():
                    # Set the parameter gradients to zero
                    optimizer.zero_grad()

                    # Forward pass
                    outputs = net(inputs)

                    # Calculate loss
                    if config.criterion == "cross_entropy":
                        loss = criterion(outputs, orginal_labels)
                    else:
                        loss = criterion(outputs, labels)

                    # Backward pass
                    loss.backward()
                    return loss

                # Standard forward pass and loss (also the first pass when SAM is used)
                outputs = net(inputs)

                # Calculate loss
                if config.criterion == "cross_entropy":
                    loss = criterion(outputs, orginal_labels)
                else:
                    loss = criterion(outputs, labels)

                _, predictions = torch.max(outputs, 1)
                correct += (predictions == orginal_labels).sum()
                total_tested += labels.size(0)
                epoch_loss += loss.item()

                # Update weights
                loss.backward()
                if config.use_SAM:
                    optimizer.step(closure)
                else:
                    optimizer.step()

        if config.scheduler is not None:
            scheduler.step()

        # Display
        if epoch % 2 == 0:
            val_loss, val_acc = test(net, data, config)

            # Turn on data augmentation once validation loss exceeds training loss late in the run
            if ((val_loss > epoch_loss/len(data.train_set)) and (epoch > 10)) or (epoch > 0.9 * config.epochs):
                data.data_augment = True
                data.train_set = data.get_trainset()
                train_loader = data.get_train_loader(config.batch_size)
                wandb.log({"Data Augmentation" : data.data_augment})

            net.train(True)
            print("Epoch: ", epoch + 1, "\tTrain Loss: ", epoch_loss/len(data.train_set), "\tVal Loss: ", val_loss)
            wandb.log({"epoch"      : epoch,
                       "Train Loss" : epoch_loss/len(data.train_set),
                       "Train Acc"  : correct/total_tested,
                       "Val Loss"   : val_loss,
                       "Val Acc"    : val_acc})

    # Test
    val_loss, val_acc = test(net, data, config)
    wandb.log({"epoch"      : epoch,
               "Train Loss" : epoch_loss/len(data.train_set),
               "Train Acc"  : correct/total_tested,
               "Val Loss"   : val_loss,
               "Val Acc"    : val_acc})

    # Save Model
    if save_model:
        # Define file name (prefix with "U" only when a unitary transformation is attached)
        if net.U is not None:
            filename = "U" + str(config.model_name) + "_w_acc_" + str(int(round(val_acc.item() * 100, 3))) + ".pt"
        else:
            filename = str(config.model_name) + "_w_acc_" + str(int(round(val_acc.item() * 100, 3))) + ".pt"

        # Save model weights
        torch.save(net.state_dict(), "models/pretrained/" + data.set_name + "/" + filename)

        # Save U
        if net.U is not None:
            torch.save(net.U, "models/pretrained/" + data.set_name + "/" + str(config.transformation) + "_for_" + data.set_name + filename)

def test(net, data, config):
    # Set to test mode
    net.eval()

    # Create loss function
    criterion = initalize_criterion(config)

    # Initialize
    total_loss = 0
    correct = 0
    total_tested = 0

    # Test data in test loader
    for i, batch_data in enumerate(data.test_loader, 0):
        # Get labels and inputs from test_loader
        inputs, labels = batch_data

        # One hot the labels
        orginal_labels = copy.copy(labels).long()
        labels = torch.eq(labels.view(labels.size(0), 1), torch.arange(10).reshape(1, 10).repeat(labels.size(0), 1)).float()

        # Push to gpu
        if data.gpu:
            orginal_labels = orginal_labels.cuda()
            inputs, labels = inputs.cuda(), labels.cuda()

        # Forward pass
        outputs = net(inputs)
        if config.criterion == "cross_entropy":
            loss = criterion(outputs, orginal_labels)
        else:
            loss = criterion(outputs, labels)

        # Update running sums
        _, predictions = torch.max(outputs, 1)
        correct += (predictions == orginal_labels).sum()
        total_tested += labels.size(0)
        total_loss += loss.item()

    # Test loss and accuracy
    test_loss = total_loss / len(data.test_set)
    test_acc = correct / total_tested

    return test_loss, test_acc

#-------------------------------------------------------------------------------------#
# Main
if __name__ == "__main__":
    # Hyperparameters
    gpu = True
    save_model = True
    project_name = "CIFAR10"
    set_name = "CIFAR10"
    # seed = 100
    os.environ['WANDB_MODE'] = 'dryrun'

    # Push to GPU if necessary
    if gpu:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = "3"

    # Declare seed and initialize network
    # torch.manual_seed(seed)

    # Load data
    data = Data(gpu = gpu, set_name = set_name, data_augment = False) # , desired_image_size = 224, test_batch_size = 32)
    print(set_name + " is Loaded")

    # Run the sweep
    config = initalize_config_defaults(sweep_config)
    net = initalize_net(data.set_name, data.gpu, config)
    print(test(net, data, config))
    # sweep_id = wandb.sweep(sweep_config, entity="naddeok", project=project_name)
    # wandb.agent(sweep_id, function=lambda: train(data, save_model))