# helpers_train_test.py
import torch
import time, argparse
from datetime import datetime
import numpy as np
from tensorboardX import SummaryWriter
from quaternions import *
import tqdm


# Generic training function: performs a single optimizer step on one batch.
def train(model, loss_fn, optimizer, x, q_gt):
    # Reset gradient
    optimizer.zero_grad()

    # Forward
    q_est = model.forward(x)
    loss = loss_fn(q_est, q_gt)

    # Backward
    loss.backward()

    # Update parameters
    optimizer.step()

    return (q_est, loss.item())


# Generic evaluation function: forward pass only, with gradient tracking disabled.
def test(model, loss_fn, x, q_gt):
    # Forward
    with torch.no_grad():
        q_est = model.forward(x)
        loss = loss_fn(q_est, q_gt)

    return (q_est, loss.item())
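
# A minimal per-batch usage sketch for the two helpers above (the batch
# names are hypothetical, not defined in this file):
#   q_est, batch_loss = train(model, loss_fn, optimizer, x_batch, q_gt_batch)
#   q_est, batch_loss = test(model, loss_fn, x_batch, q_gt_batch)
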
def train_test_model(args, loss_fn, model, train_loader, test_loader, tensorboard_output=True, progress_bar=True, scheduler=False):
    if tensorboard_output:
        writer = SummaryWriter()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    if scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.2)

    # Save stats: one (loss, mean angular error) pair per epoch
    train_stats = torch.zeros(args.epochs, 2)
    test_stats = torch.zeros(args.epochs, 2)

    device = next(model.parameters()).device
    tensor_type = torch.double if args.double else torch.float
    rotmat_targets = train_loader.dataset.rotmat_targets

    for e in range(args.epochs):
        start_time = time.time()

        # Train model
        model.train()
        train_loss = torch.tensor(0.)
        train_mean_err = torch.tensor(0.)
        num_train_batches = len(train_loader)
        if progress_bar:
            pbar = tqdm.tqdm(total=num_train_batches)
        for _, (x, target) in enumerate(train_loader):
            # Move all data to the appropriate device
            target = target.to(device=device, dtype=tensor_type)
            x = x.to(device=device, dtype=tensor_type)
            (rot_est, train_loss_k) = train(model, loss_fn, optimizer, x, target)
            if rotmat_targets:
                train_mean_err += (1. / num_train_batches) * rotmat_angle_diff(rot_est, target)
            else:
                train_mean_err += (1. / num_train_batches) * quat_angle_diff(rot_est, target)
            train_loss += (1. / num_train_batches) * train_loss_k
            if progress_bar:
                pbar.update(1)
        if progress_bar:
            pbar.close()

        # Test model
        model.eval()
        num_test_batches = len(test_loader)
        test_loss = torch.tensor(0.)
        test_mean_err = torch.tensor(0.)
        for _, (x, target) in enumerate(test_loader):
            # Move all data to the appropriate device
            target = target.to(device=device, dtype=tensor_type)
            x = x.to(device=device, dtype=tensor_type)
            (rot_est, test_loss_k) = test(model, loss_fn, x, target)
            if rotmat_targets:
                test_mean_err += (1. / num_test_batches) * rotmat_angle_diff(rot_est, target)
            else:
                test_mean_err += (1. / num_test_batches) * quat_angle_diff(rot_est, target)
            test_loss += (1. / num_test_batches) * test_loss_k

        test_stats[e, 0] = test_loss
        test_stats[e, 1] = test_mean_err

        if tensorboard_output:
            writer.add_scalar('validation/loss', test_loss, e)
            writer.add_scalar('validation/mean_err', test_mean_err, e)
            writer.add_scalar('training/loss', train_loss, e)
            writer.add_scalar('training/mean_err', train_mean_err, e)

        # History tracking
        train_stats[e, 0] = train_loss
        train_stats[e, 1] = train_mean_err

        elapsed_time = time.time() - start_time

        output_string = 'Epoch: {}/{}. Train: Loss {:.3E} / Error {:.3f} (deg) | Test: Loss {:.3E} / Error {:.3f} (deg). Epoch time: {:.3f} sec.'.format(
            e + 1, args.epochs, train_loss, train_mean_err, test_loss, test_mean_err, elapsed_time)
        print(output_string)

        if scheduler:
            scheduler.step()

    if tensorboard_output:
        writer.close()

    return train_stats, test_stats
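

# A minimal, self-contained usage sketch (not part of the original file).
# The dataset, model, and loss below are hypothetical stand-ins; the only
# contract assumed is the one the code above relies on: the loader's dataset
# exposes a `rotmat_targets` flag, and the model maps inputs to unit quaternions.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset

    class ToyQuatDataset(Dataset):
        # Random inputs paired with random unit-quaternion targets.
        rotmat_targets = False  # targets are quaternions, not rotation matrices

        def __init__(self, n=256, dim=16):
            self.x = torch.randn(n, dim)
            q = torch.randn(n, 4)
            self.q = q / q.norm(dim=1, keepdim=True)

        def __len__(self):
            return self.x.shape[0]

        def __getitem__(self, i):
            return self.x[i], self.q[i]

    class ToyQuatNet(nn.Module):
        # Linear map to R^4, normalized onto the unit-quaternion sphere.
        def __init__(self, dim=16):
            super().__init__()
            self.fc = nn.Linear(dim, 4)

        def forward(self, x):
            q = self.fc(x)
            return q / q.norm(dim=1, keepdim=True)

    def toy_loss(q_est, q_gt):
        # Simple squared-error stand-in for a proper rotation loss.
        return ((q_est - q_gt) ** 2).sum(dim=1).mean()

    args = argparse.Namespace(lr=1e-3, epochs=2, double=False)
    train_loader = DataLoader(ToyQuatDataset(), batch_size=32)
    test_loader = DataLoader(ToyQuatDataset(n=64), batch_size=32)

    train_stats, test_stats = train_test_model(
        args, toy_loss, ToyQuatNet(), train_loader, test_loader,
        tensorboard_output=False, progress_bar=False)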