-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
134 lines (113 loc) · 5.21 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.nn.utils import clip_grad_norm_
from utils.meters import AverageMeter, accuracy
from img_utils import iou, pixel_accuracy
import matplotlib.pyplot as plt
class Trainer(object):
    """Run training / evaluation passes of a model and track running metrics.

    Optionally wraps ``model`` in ``DistributedDataParallel`` or
    ``DataParallel``, performs forward/backward/optimizer steps, and keeps
    running averages of step time, data time, loss, IoU, pixel accuracy and
    (when clipping is enabled during training) the gradient norm.
    """

    def __init__(self, model, criterion, optimizer=None,
                 device_ids=0, device=torch.cuda, dtype=torch.float,
                 distributed=False, local_rank=-1,
                 grad_clip=-1, print_freq=100, n_classes=24):
        """Store the configuration and wrap ``model`` for parallelism if asked.

        Args:
            model: the network to train/evaluate.
            criterion: callable ``criterion(output, target)`` returning the loss.
            optimizer: object exposing ``update(epoch, step)``, ``zero_grad()``
                and ``step()``; only used when training.
            device_ids: sized sequence of device ids; more than one entry
                enables ``nn.DataParallel``. Non-sequence values (such as the
                default ``0``) mean no data parallelism.
            device: target device for inputs and targets. The legacy default
                ``torch.cuda`` is translated to ``torch.device('cuda')``.
            dtype: dtype the inputs are cast to.
            distributed: wrap the model in ``DistributedDataParallel`` on
                ``local_rank``.
            local_rank: device index used when ``distributed`` is true.
            grad_clip: maximum total gradient norm; values <= 0 disable clipping.
            print_freq: log every ``print_freq`` batches; values <= 0 disable
                periodic logging.
            n_classes: number of classes passed to ``iou``.
        """
        self._model = model
        self.criterion = criterion
        self.epoch = 0
        self.training_steps = 0
        self.optimizer = optimizer
        # ``torch.cuda`` is a module, not a device: ``Tensor.to(torch.cuda)``
        # raises, so map the historical default onto a real device object.
        if device is torch.cuda:
            device = torch.device('cuda')
        self.device = device
        self.dtype = dtype
        self.local_rank = local_rank
        self.print_freq = print_freq
        self.grad_clip = grad_clip
        self.n_classes = n_classes

        def empty_reg(m):
            return 0

        # Optional hooks provided by the model itself; default to no-ops.
        self.regularizer = getattr(model, 'regularization', empty_reg)
        self.regularizer_pre_step = getattr(
            model, 'regularization_pre_step', empty_reg)
        self.regularizer_post_step = getattr(
            model, 'regularization_post_step', empty_reg)

        if distributed:
            self.model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
        elif hasattr(device_ids, '__len__') and len(device_ids) > 1:
            # Plain ints (e.g. the default 0) have no len(); treat them as
            # "no data parallelism" instead of raising TypeError.
            self.model = nn.DataParallel(model, device_ids)
        else:
            self.model = model

    def _step(self, inputs, target, training=False):
        """Run one forward pass and, if ``training``, one optimizer step.

        Returns:
            ``(output, loss, grad)``: ``output`` is the model output (its first
            element when the model returns a list/tuple); ``loss`` is the
            criterion value plus the model's regularization term; ``grad`` is
            the total gradient norm returned by ``clip_grad_norm_`` when
            training with ``grad_clip > 0``, otherwise ``None``.
        """
        output = self.model(inputs)
        # The criterion receives the raw output (possibly a list/tuple).
        loss = self.criterion(output, target)
        # Out-of-place add: an in-place ``+=`` on the criterion's output can
        # invalidate a tensor autograd saved for the backward pass.
        loss = loss + self.regularizer(self.model)
        grad = None
        if isinstance(output, (list, tuple)):
            output = output[0]

        if training:
            self.optimizer.update(self.epoch, self.training_steps)
            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.regularizer_pre_step(self.model)
            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()
            self.regularizer_post_step(self.model)
            self.training_steps += 1

        return output, loss, grad

    def forward(self, data_loader, num_steps=None, training=False):
        """Iterate once over ``data_loader``, updating and logging meters.

        Args:
            data_loader: iterable of ``(inputs, target)`` batches that also
                supports ``len()`` (used in the log line).
            num_steps: if not ``None``, process at most this many batches.
            training: perform optimizer steps (see ``_step``) when true.

        Returns:
            dict mapping each meter name to its running average.
        """
        meter_names = ['step', 'data', 'loss', 'iou', 'pixel_accuracy']
        meters = {name: AverageMeter() for name in meter_names}
        if training and self.grad_clip > 0:
            meters['grad'] = AverageMeter()

        end = time.time()
        for i, (inputs, target) in enumerate(data_loader):
            # Check before doing any work so exactly ``num_steps`` batches are
            # processed (checking after the step ran one batch too many).
            if num_steps is not None and i >= num_steps:
                break

            # measure data loading time
            meters['data'].update(time.time() - end)
            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            output, loss, grad = self._step(inputs, target, training=training)

            # measure accuracy and record loss
            batch_size = inputs.size(0)
            meters['loss'].update(float(loss), batch_size)
            meters['iou'].update(iou(output, target, n_classes=self.n_classes))
            meters['pixel_accuracy'].update(pixel_accuracy(output, target))
            if grad is not None:
                meters['grad'].update(float(grad), batch_size)

            # measure elapsed time
            meters['step'].update(time.time() - end)
            end = time.time()

            if self.print_freq > 0 and i % self.print_freq == 0:
                report = ('{phase} - Epoch: [{0}][{1}/{2}]\t'
                          'Time {meters[step].val:.3f} ({meters[step].avg:.3f})\t'
                          'Data {meters[data].val:.3f} ({meters[data].avg:.3f})\t'
                          'Loss {meters[loss].val:.4f} ({meters[loss].avg:.4f})\t'
                          'pixel_accuracy {meters[pixel_accuracy].val:.3f} ({meters[pixel_accuracy].avg:.3f})\t'
                          'iou {meters[iou].val:.3f} ({meters[iou].avg:.3f})\t'
                          .format(
                              self.epoch, i, len(data_loader),
                              phase='TRAINING' if training else 'EVALUATING',
                              meters=meters))
                if 'grad' in meters:
                    report += 'Grad {meters[grad].val:.3f} ({meters[grad].avg:.3f})'\
                        .format(meters=meters)
                logging.info(report)

        return {name: meter.avg for name, meter in meters.items()}

    def train(self, data_loader):
        """Switch the model to train mode and run one training pass."""
        self.model.train()
        return self.forward(data_loader, training=True)

    def validate(self, data_loader):
        """Switch the model to eval mode and run one pass without autograd."""
        self.model.eval()
        with torch.no_grad():
            return self.forward(data_loader, training=False)