-
Notifications
You must be signed in to change notification settings - Fork 77
/
train.py
168 lines (142 loc) · 6.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

import transformer
import utils
import vgg
# GLOBAL SETTINGS
# Training images are resized and center-cropped to this square size (pixels).
TRAIN_IMAGE_SIZE = 256
# Root directory handed to torchvision ``datasets.ImageFolder``, which expects
# the images to live inside at least one subdirectory (e.g. dataset/<any>/x.jpg).
DATASET_PATH = "dataset"
NUM_EPOCHS = 1
STYLE_IMAGE_PATH = "images/mosaic.jpg"
BATCH_SIZE = 4
# Loss weights. The original author also noted alternative values
# (CONTENT_WEIGHT 17, STYLE_WEIGHT 25) next to these settings.
CONTENT_WEIGHT = 17
STYLE_WEIGHT = 50
ADAM_LR = 0.001
# Output directories. They are used as plain string prefixes when building
# file names, so keep the trailing slash.
SAVE_MODEL_PATH = "models/"
SAVE_IMAGE_PATH = "images/out/"
# Checkpoint / sample-image cadence, in batches
# (500 batches x batch size 4 = 2,000 images).
SAVE_MODEL_EVERY = 500
SEED = 35
# Whether to plot the loss histories once training finishes.
PLOT_LOSS = True
def train():
    """Train the style-transfer TransformerNetwork against one style image.

    All configuration comes from the module-level constants (dataset path,
    image size, loss weights, learning rate, checkpoint cadence, output paths).

    Steps, as implemented below:
      1. Seed torch / CUDA / numpy / random.
      2. Build an ``ImageFolder`` loader that resizes, center-crops, converts
         to a tensor and scales pixel values by 255.
      3. Compute Gram matrices of the style image's VGG features once.
      4. For each content batch: run the transformer, then compute
         - a content loss (MSE on the ``'relu2_2'`` VGG features), and
         - a style loss (MSE between Gram matrices, summed over every feature
           key VGG returns),
         and take an Adam step on the weighted sum.
      5. On the first batch, every ``SAVE_MODEL_EVERY`` batches after it, and
         on the final batch:
         - print running-average losses,
         - save a checkpoint and one sample image,
         - append the running averages to the loss histories.
      6. Move the network to CPU and save its final weights. Optionally plot
         the loss histories.

    Side effects: creates SAVE_MODEL_PATH / SAVE_IMAGE_PATH if missing, writes
    checkpoint, sample and final-weight files there, and prints progress.
    """
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create the output directories up front so torch.save / utils.saveimg
    # do not fail after training has already been running for a while.
    for out_dir in (SAVE_MODEL_PATH, SAVE_IMAGE_PATH):
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),  # [0, 1] -> [0, 255]
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Load networks
    transformer_net = transformer.TransformerNetwork().to(device)
    vgg_net = vgg.VGG16().to(device)

    # Per-channel negative means, added to every image before it enters VGG.
    # NOTE(review): the values match the common Caffe ImageNet BGR means —
    # presumably the VGG weights expect BGR input; confirm against vgg.VGG16.
    imagenet_neg_mean = torch.tensor(
        [-103.939, -116.779, -123.68], dtype=torch.float32
    ).reshape(1, 3, 1, 1).to(device)

    # Style targets: computed once. no_grad keeps them free of autograd history,
    # so every later backward() only walks the current iteration's graph.
    # Assumes utils.itot returns a (1, C, H, W) tensor in the same channel
    # order and scale as the content batches — TODO confirm.
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    style_tensor = utils.itot(style_image).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    _, C, H, W = style_tensor.shape
    with torch.no_grad():
        style_features = vgg_net(style_tensor.expand([BATCH_SIZE, C, H, W]))
        style_gram = {key: utils.gram(value) for key, value in style_features.items()}

    # Optimizer and loss (the loss module is stateless, so build it once).
    optimizer = optim.Adam(transformer_net.parameters(), lr=ADAM_LR)
    mse_loss = nn.MSELoss().to(device)

    # Loss trackers. The sums hold plain Python floats (via .item()), so no
    # tensors or autograd graphs are kept alive across iterations.
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0.0
    batch_style_loss_sum = 0.0
    batch_total_loss_sum = 0.0

    # Optimization/Training Loop
    total_batches = NUM_EPOCHS * len(train_loader)
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # The last batch of an epoch may be smaller than BATCH_SIZE.
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            optimizer.zero_grad()

            # Reverse the channel order (RGB from ImageFolder/PIL -> BGR),
            # presumably to match the BGR mean above — confirm with vgg.VGG16.
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = transformer_net(content_batch)
            # Content targets are constants; no need to record a graph for them.
            with torch.no_grad():
                content_features = vgg_net(content_batch.add(imagenet_neg_mean))
            generated_features = vgg_net(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            content_loss = CONTENT_WEIGHT * mse_loss(
                generated_features['relu2_2'], content_features['relu2_2']
            )
            batch_content_loss_sum += content_loss.item()

            # Style Loss: slice the precomputed Grams to the current batch size.
            style_loss = 0
            for key, value in generated_features.items():
                style_loss += mse_loss(utils.gram(value), style_gram[key][:curr_batch_size])
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Report / checkpoint on batch 1, every SAVE_MODEL_EVERY batches
            # after that, and on the very last batch.
            if (batch_count - 1) % SAVE_MODEL_EVERY == 0 or batch_count == total_batches:
                # Print running-average losses since the start of training
                print("========Iteration {}/{}========".format(batch_count, total_batches))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum / batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum / batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum / batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() - start_time))

                # Save Model
                checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str(batch_count - 1) + ".pth"
                torch.save(transformer_net.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path))

                # Save sample generated image (a detached copy of the first item)
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0)
                sample_image = utils.ttoi(sample_tensor)
                sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str(batch_count - 1) + ".png"
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(sample_image_path))

                # Save loss histories (content history previously recorded the
                # total loss by mistake)
                content_loss_history.append(batch_content_loss_sum / batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            batch_count += 1

    stop_time = time.time()

    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights (moved to CPU first)
    transformer_net.eval()
    transformer_net.cpu()
    final_path = SAVE_MODEL_PATH + "transformer_weight.pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(transformer_net.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if PLOT_LOSS:
        utils.plot_loss_hist(content_loss_history, style_loss_history, total_loss_history)
# Run training only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    train()