-
Notifications
You must be signed in to change notification settings - Fork 1
/
training.py
144 lines (119 loc) · 4.56 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Standard library imports
import os
import json
import math
import random
# Third party imports
import glob
import json
import pickle
from PIL import Image, ImageDraw
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, models, transforms
from torch.optim import lr_scheduler
from torchsummary import summary
# Local imports
import config
import inference
from modules import model
from modules import train
from modules import util
# Module-level training state. training() rebinds these through ``global``
# and may overwrite them from a checkpoint when resuming.
DEVICE = config.DEVICE
# Last completed epoch of each phase; 0 means the phase starts at epoch 1.
CURRENT_FREEZE_EPOCH = CURRENT_UNFREEZE_EPOCH = 0
# Initial "best" validation loss: an epoch is only flagged as best once its
# loss is strictly below this value. NOTE(review): 4 looks like an arbitrary
# cap rather than +inf — confirm it is intentional.
BEST_LOSS = 4
def _split_train_val(images, train_fraction=0.8):
    """Split an (already shuffled) image list into disjoint train/val lists.

    A single split index is used for both slices, so every image ends up in
    exactly one of the two lists, whatever the length of *images*.

    Args:
        images: sequence of image keys.
        train_fraction: fraction of *images* assigned to the training set.

    Returns:
        (train_images, val_images)
    """
    split = math.floor(len(images) * train_fraction)
    return images[:split], images[split:]


def _build_triplet_net():
    """Create a pretrained ResNet-50 with a new dense head, wrapped in a Tripletnet.

    Every pretrained parameter is frozen; only the freshly created head
    (in_features -> 1024 -> 512 -> 256) is trainable.

    Returns:
        (base_model, tnet): the embedding network and its triplet wrapper.
    """
    base_model = models.resnet50(pretrained=True).to(DEVICE)
    for param in base_model.parameters():
        param.requires_grad = False
    num_ftrs = base_model.fc.in_features
    # The Linear layers are stacked without activations. Kept unchanged so
    # existing checkpoints (Sequential indices 0, 1, 2) still load.
    base_model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 1024), nn.Linear(1024, 512), nn.Linear(512, 256))
    # The new head is created on the CPU, so move the model again.
    base_model = base_model.to(DEVICE)
    tnet = model.Tripletnet(base_model).to(DEVICE)
    return base_model, tnet


def _unfreeze_top_layers(base_model, num_frozen_children=6):
    """Make every child module trainable except the first *num_frozen_children*.

    For torchvision's ResNet-50 the first six children are conv1, bn1, relu,
    maxpool, layer1 and layer2. layer3, layer4, avgpool and the dense head
    therefore become trainable.
    """
    for param in base_model.parameters():
        param.requires_grad = True
    for index, child in enumerate(base_model.children(), start=1):
        if index <= num_frozen_children:
            for param in child.parameters():
                param.requires_grad = False


def _run_epoch(epoch, train_images, val_images, positive_image_dict,
               base_model, criterion, optimizer):
    """Train for one epoch and return the validation loss.

    Epochs alternate the trainer's ``batch_hard`` flag: odd epochs pass
    ``True``, even epochs pass ``False``, to both train() and validate().
    """
    trainer = train.Train(train_images, val_images, positive_image_dict,
                          base_model, criterion, optimizer, epoch)
    batch_hard = epoch % 2 == 1
    trainer.train(batch_hard=batch_hard)
    return trainer.validate(batch_hard=batch_hard)


def _save_epoch_checkpoint(tnet, loss, freeze_epoch, unfreeze_epoch):
    """Fold *loss* into the global BEST_LOSS and save a checkpoint.

    ``util.save_checkpoint`` receives ``is_best=True`` only when *loss*
    strictly improves on the best loss seen so far.
    """
    global BEST_LOSS
    is_best = loss < BEST_LOSS
    BEST_LOSS = min(loss, BEST_LOSS)
    util.save_checkpoint({
        'current_freeze_epoch': freeze_epoch,
        'current_unfreeze_epoch': unfreeze_epoch,
        'state_dict': tnet.state_dict(),
        'best_loss': BEST_LOSS,
    }, is_best)


def training(args):
    """Train the triplet network in two phases: frozen backbone, then partially unfrozen.

    Phase 1 runs ``config.FREEZE_EPOCHS`` epochs with only the dense head
    trainable. Phase 2 runs ``config.UNFREEZE_EPOCHS`` epochs with the later
    ResNet stages also trainable. After every epoch the best loss is updated,
    a checkpoint is saved, and ``inference.inference`` is called. When
    ``config.ARGS.resume`` is set, epoch counters, best loss and weights are
    restored from ``config.ARGS.checkpoint_name`` if possible.

    Args:
        args: not read by this function; all settings come from
            ``config.ARGS``. NOTE(review): confirm callers do not expect
            ``args`` to be honoured.

    Side effects:
        Rebinds the module globals BEST_LOSS, CURRENT_FREEZE_EPOCH and
        CURRENT_UNFREEZE_EPOCH.
    """
    global BEST_LOSS
    global CURRENT_FREEZE_EPOCH
    global CURRENT_UNFREEZE_EPOCH

    # NOTE: pickle.load can execute arbitrary code; only load files you trust.
    with open(config.ARGS.positive_image_dict, "rb") as f:
        positive_image_dict = pickle.load(f)
    images = list(positive_image_dict.keys())
    # NOTE(review): unless the global RNG is seeded elsewhere, a resumed run
    # gets a different train/val split than the run that wrote the checkpoint.
    random.shuffle(images)
    train_images, val_images = _split_train_val(images)

    base_model, tnet = _build_triplet_net()
    if config.ARGS.resume:
        try:
            (CURRENT_FREEZE_EPOCH, CURRENT_UNFREEZE_EPOCH, BEST_LOSS,
             tnet) = util.load_checkpoint(config.ARGS.checkpoint_name, tnet)
        except Exception as err:
            # Resuming is best-effort: fall back to a fresh run, but say why.
            print("not able to load checkpoint because of non-availability "
                  "({!r})".format(err))

    criterion = torch.nn.MarginRankingLoss(margin=config.TRIPLET_MARGIN)
    # All parameters are registered up front. Frozen ones receive no gradient
    # and are skipped by SGD until they are unfrozen in phase 2.
    optimizer = optim.SGD(tnet.parameters(), lr=config.LR, momentum=config.MOMENTUM)

    n_parameters = sum(p.numel() for p in tnet.parameters())
    print(' + Number of params: {}'.format(n_parameters))

    # Phase 1: backbone frozen, only the dense head trains.
    # The range is empty when the phase was already completed (resume).
    for epoch in range(CURRENT_FREEZE_EPOCH + 1, config.FREEZE_EPOCHS + 1):
        loss = _run_epoch(epoch, train_images, val_images, positive_image_dict,
                          base_model, criterion, optimizer)
        _save_epoch_checkpoint(tnet, loss, freeze_epoch=epoch, unfreeze_epoch=0)
        CURRENT_FREEZE_EPOCH = epoch
        # NOTE(review): called without an epoch here but with one in phase 2.
        inference.inference()

    # Phase 2: additionally train the later ResNet stages.
    _unfreeze_top_layers(base_model)
    for epoch in range(CURRENT_UNFREEZE_EPOCH + 1, config.UNFREEZE_EPOCHS + 1):
        loss = _run_epoch(epoch, train_images, val_images, positive_image_dict,
                          base_model, criterion, optimizer)
        _save_epoch_checkpoint(tnet, loss, freeze_epoch=CURRENT_FREEZE_EPOCH,
                               unfreeze_epoch=epoch)
        CURRENT_UNFREEZE_EPOCH = epoch
        inference.inference(epoch)