train.py
import data
import models
import optimizers
from options import TrainOptions
from util import IterationCounter
from util import Visualizer
from util import MetricTracker
from evaluation import GroupEvaluator

# Parse command-line options, then build the training pipeline: the dataset,
# bookkeeping helpers (iteration counting, visualization, metric tracking),
# the evaluator, the model, and its optimizer.
opt = TrainOptions().parse()
dataset = data.create_dataset(opt)
opt.dataset = dataset
iter_counter = IterationCounter(opt)
visualizer = Visualizer(opt)
metric_tracker = MetricTracker(opt)
evaluators = GroupEvaluator(opt)

model = models.create_model(opt)
optimizer = optimizers.create_optimizer(opt, model)
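
# Note (illustration, not part of the original file): `next(dataset)` in the
# loop below implies that create_dataset returns an iterator rather than a
# plain DataLoader. A common pattern is an endless wrapper like the sketch
# here; the repo's own data.create_dataset may differ, and this helper is
# not used by the script.
def _endless_loader_sketch(loader):
    # Cycle over the loader forever so the training loop can call next()
    # without handling epoch boundaries or StopIteration.
    while True:
        for batch in loader:
            yield batch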

# Training loop: each iteration times three phases (data loading, the
# optimization step, and maintenance work such as logging, visualization,
# evaluation, and checkpointing) until the configured number of steps is
# reached.
while not iter_counter.completed_training():
    with iter_counter.time_measurement("data"):
        cur_data = next(dataset)

    with iter_counter.time_measurement("train"):
        losses = optimizer.train_one_step(cur_data, iter_counter.steps_so_far)
        metric_tracker.update_metrics(losses, smoothe=True)

    with iter_counter.time_measurement("maintenance"):
        if iter_counter.needs_printing():
            visualizer.print_current_losses(iter_counter.steps_so_far,
                                            iter_counter.time_measurements,
                                            metric_tracker.current_metrics())

        if iter_counter.needs_displaying():
            visuals = optimizer.get_visuals_for_snapshot(cur_data)
            visualizer.display_current_results(visuals,
                                               iter_counter.steps_so_far)

        if iter_counter.needs_evaluation():
            metrics = evaluators.evaluate(
                model, dataset, iter_counter.steps_so_far)
            metric_tracker.update_metrics(metrics, smoothe=False)

        if iter_counter.needs_saving():
            optimizer.save(iter_counter.steps_so_far)

        if iter_counter.completed_training():
            break

        iter_counter.record_one_iteration()

# Save a final checkpoint after the loop exits.
optimizer.save(iter_counter.steps_so_far)
print('Training finished.')
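
# Illustration (assumption, not from this repo): iter_counter.time_measurement
# is used above as a context manager that times each phase of the loop. A
# minimal self-contained sketch of that idea, assuming it only needs to
# accumulate wall-clock seconds per phase name; the real IterationCounter in
# util may record and report timings differently.
import time
from collections import defaultdict
from contextlib import contextmanager

class _PhaseTimerSketch:
    """Accumulates wall-clock time per named phase ("data", "train", ...)."""

    def __init__(self):
        self.time_measurements = defaultdict(float)

    @contextmanager
    def time_measurement(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Add the elapsed time for this phase, even if the body raised.
            self.time_measurements[name] += time.perf_counter() - start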