train_framework.py
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable as V
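
# TrainFramework wraps a segmentation network with helpers for training
# (optimize), inference (test_*), checkpointing (save/load), and
# learning-rate scheduling (update_lr).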
class TrainFramework:
    def __init__(self, net, loss, lr=3e-4, evalmode=False, num_classes=7):
        self.net = net(num_classes=num_classes).cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
        self.optimizer = torch.optim.Adam(params=self.net.parameters(), lr=lr)
        self.loss = loss
        self.old_lr = lr
        if evalmode:
            # freeze BatchNorm running statistics during evaluation
            for i in self.net.modules():
                if isinstance(i, nn.BatchNorm2d):
                    i.eval()

    def set_input(self, img_batch, mask_batch=None, img_id=None):
        self.img = img_batch
        self.mask = mask_batch
        self.img_id = img_id

    def test_one_img(self, img):
        pred = self.net(img)
        # per-pixel argmax over the class dimension replaces the 0.5
        # thresholding from the original binary version
        pred = torch.argmax(pred, dim=1)
        mask = pred.squeeze().cpu().data.numpy()
        return mask

    def test_batch(self):
        self.forward(volatile=True)
        pred = self.net(self.img)
        # take the argmax on the tensor before converting to numpy;
        # the original called torch.argmax on a numpy array, which fails
        mask = torch.argmax(pred, dim=1).cpu().data.numpy()
        return mask, self.img_id

    def test_one_img_from_path(self, path):
        img = cv2.imread(path)
        img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6  # scale pixels to [-1.6, 1.6]
        # HWC -> CHW plus a batch dimension, as a conv net expects
        img = V(torch.Tensor(img).permute(2, 0, 1).unsqueeze(0).cuda())
        pred = self.net(img)
        # argmax over classes on the tensor, then move to numpy
        mask = torch.argmax(pred, dim=1).squeeze().cpu().data.numpy()
        return mask

    def forward(self, volatile=False):
        # volatile is a no-op in modern PyTorch; torch.no_grad() is the
        # current way to disable gradient tracking for inference
        self.img = V(self.img.cuda(), volatile=volatile)
        if self.mask is not None:
            self.mask = V(self.mask.cuda(), volatile=volatile)

    def optimize(self):
        self.forward()
        self.optimizer.zero_grad()
        pred = self.net(self.img)
        loss = self.loss(pred, self.mask)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def save(self, path):
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        self.net.load_state_dict(torch.load(path))

    def update_lr(self, new_lr, mylog, factor=False):
        if factor:
            # treat new_lr as a divisor of the current learning rate,
            # with a floor of 2e-5
            new_lr = self.old_lr / new_lr
            new_lr = max(new_lr, 0.00002)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
        mylog.write('update learning rate: %f -> %f\n' % (self.old_lr, new_lr))
        self.old_lr = new_lr
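
A minimal usage sketch follows. MyNet (a network class taking num_classes), train_loader (yielding image/mask batches), and the weights/ output path are assumptions for illustration; none of them appear in the file above. The criterion here is nn.CrossEntropyLoss, which matches the argmax-over-classes inference used by the test_* methods.

# Hypothetical training loop -- MyNet, train_loader, and the save path
# are placeholders, not part of train_framework.py.
import torch.nn as nn

solver = TrainFramework(net=MyNet, loss=nn.CrossEntropyLoss(), lr=3e-4, num_classes=7)

for epoch in range(100):
    epoch_loss = 0.0
    for img_batch, mask_batch in train_loader:
        # hand the batch to the framework, then run one optimization step
        solver.set_input(img_batch, mask_batch)
        epoch_loss += solver.optimize()
    print('epoch %d: loss %.4f' % (epoch, epoch_loss / len(train_loader)))
    solver.save('weights/epoch_%d.th' % epoch)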