forked from singh-hrituraj/PixelCNN-Pytorch
train.py
151 lines (116 loc) · 4.16 KB
'''
Code by Hrituraj Singh
Indian Institute of Technology Roorkee
'''
import sys
import os
import time

import torch
import torch.nn as nn
from torch import optim
from torch.utils import data
from torch.autograd import Variable
import matplotlib.pyplot as plt

from utils import *
from Model import PixelCNN


def main(config_file):
    config = parse_config(config_file)
    data_ = config['data']
    network = config['network']

    path = data_.get('path', 'Data')              # Directory where the downloaded data is stored
    data_name = data_.get('data_name', 'MNIST')   # Which dataset to load, e.g. MNIST or CIFAR
    batch_size = data_.get('batch_size', 144)
    layers = network.get('no_layers', 8)          # Number of layers in the network
    kernel = network.get('kernel', 7)             # Kernel size
    channels = network.get('channels', 64)        # Depth of the intermediate layers
    epochs = network.get('epochs', 25)            # Number of epochs
    save_path = network.get('save_path', 'Models_36_epochs250')  # Directory where model checkpoints are saved
    lr = 1e-4                                     # Learning rate
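
    # For reference, a config file covering the keys read above might look like the
    # sketch below (values shown are the defaults used in this script). The layout is
    # an assumption for illustration only: the actual file format (YAML, INI, JSON, ...)
    # depends on what parse_config in utils expects.
    #
    #   data:
    #     path: Data
    #     data_name: MNIST
    #     batch_size: 144
    #   network:
    #     no_layers: 8
    #     kernel: 7
    #     channels: 64
    #     epochs: 25
    #     save_path: Models_36_epochs250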
    # Loading the data
    if data_name == 'MNIST':
        train, test = get_MNIST(path)

    N_train = len(train)
    N_test = len(test)

    train = data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True)
    test = data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
    # Define the model and train it under the chosen loss function
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = PixelCNN().to(device)
    if torch.cuda.device_count() > 1:  # If more than one GPU is available, accelerate training with all of them
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()  # Only used by the (commented-out) 256-way cross-entropy objective below

    total_train_loss = []
    total_test_loss = []

    time_start = time.time()
    print('Training Started')
    print(len(train))  # Number of batches per epoch

    for i in range(epochs):
        net.train(True)
        step = 0
        train_loss_sum = 0.0

        for images, labels in train:
            # Integer pixel targets in [0, 255]; only needed for the (commented-out)
            # cross-entropy objective below
            target = Variable(images[:, 0, :, :] * 255).long()
            images = images.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = net(images)
            # loss = criterion(output, target)
            loss = discretized_mix_logistic_loss_1d(images, output)
            loss.backward()
            optimizer.step()

            train_loss_sum += loss.item()
            step += 1
            # if step % 100 == 0:
            #     print('Epoch: ' + str(i) + '\t' + str(step) + ' iterations complete\tloss: ', loss.item() / 1000.0)

        train_loss_mean = train_loss_sum / N_train
        total_train_loss.append(train_loss_mean)
        print('Epoch: ' + str(i) + '\t' + str(step) + ' iterations complete\ttrain_loss: ', train_loss_mean)

        # Evaluate on the test set (no gradients needed)
        net.eval()
        test_loss_sum = 0.0
        with torch.no_grad():
            for images, labels in test:
                images = images.to(device)
                output = net(images)
                loss = discretized_mix_logistic_loss_1d(images, output)
                test_loss_sum += loss.item()

        test_loss_mean = test_loss_sum / N_test
        total_test_loss.append(test_loss_mean)
        print('Epoch: ' + str(i) + '\t' + str(step) + ' iterations complete\ttest_loss: ', test_loss_mean)
        print('Epoch: ' + str(i) + ' over!')

        # Save a checkpoint after every epoch
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        print("Saving Checkpoint!")
        if i == epochs - 1:
            torch.save(net.state_dict(), save_path + '/Model_Checkpoint_Last.pt')
        else:
            torch.save(net.state_dict(), save_path + '/Model_Checkpoint_' + str(i) + '.pt')
        print('Checkpoint Saved')

    print('Training Finished! Time Taken: ', time.time() - time_start)

    # Plot the per-epoch training and test losses
    x1 = range(0, epochs)
    x2 = range(0, epochs)
    y1 = total_train_loss
    y2 = total_test_loss

    plt.subplot(2, 1, 1)
    plt.plot(x1, y1, 'o')
    plt.title('LOSS')
    plt.ylabel('Train loss')

    plt.subplot(2, 1, 2)
    plt.plot(x2, y2, 'o')
    plt.xlabel('Epochs')
    plt.ylabel('Test loss')

    # plt.show()
    plt.savefig('loss_fig_250.jpg')
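

# The helper below is an illustrative sketch, not part of the original training script:
# it shows one way a checkpoint written by main() could be reloaded later, e.g. for
# sampling or evaluation. The function name and its arguments are hypothetical.
def load_checkpoint(checkpoint_path, device='cpu'):
    net = PixelCNN()
    state = torch.load(checkpoint_path, map_location=device)
    # Checkpoints saved from an nn.DataParallel-wrapped model carry a 'module.' prefix
    # on every key, which has to be stripped before loading into a bare PixelCNN.
    state = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state.items()}
    net.load_state_dict(state)
    net.eval()
    return net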


if __name__ == "__main__":
    config_file = sys.argv[1]  # Path to the configuration file, passed as the first CLI argument
    assert os.path.exists(config_file), "Configuration file does not exist!"
    main(config_file)
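
# Usage (the config filename below is only an example; pass whatever file
# parse_config in utils expects):
#   python train.py config.yaml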