forked from ozan-oktay/Attention-Gated-Networks
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_segmentation.py
107 lines (81 loc) · 4.25 KB
/
train_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataio.loader import get_dataset, get_dataset_path
from dataio.transformation import get_dataset_transformation
from utils.util import json_file_to_pyobj
from utils.visualiser import Visualiser
from utils.error_logger import ErrorLogger
from models import get_model
def train(arguments):
    """Train a segmentation network described by a JSON config file.

    Builds the dataset class, dataset path, augmentation transforms and model
    from the config. It then runs the epoch loop. Each epoch does:

    * one training pass;
    * one validation pass and one test pass;
    * plotting and printing of the logged errors for every split;
    * a periodic checkpoint;
    * a call to ``model.update_learning_rate()``.

    Args:
        arguments: Parsed command-line namespace exposing ``config`` (path to
            the JSON training config) and ``debug`` (bool). When ``debug`` is
            true, the model's parameter count and per-sample forward/backward
            timings are printed and the function returns without training.

    Config keys read directly here under ``training``: ``arch_type``,
    ``preloadData``, ``batchSize``, ``n_epochs`` and ``save_epoch_freq``.
    The optional ``num_workers`` key (default 16) sets the DataLoader worker
    count. ``training`` is also passed to ``model.set_scheduler``, which may
    read further keys.
    """
    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # The architecture type selects the dataset class, its path and its transforms
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation)

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(*model.get_fp_bp_time()))
        # Return rather than call the interactive-only ``exit()`` builtin, so
        # code that imports ``train`` is not terminated. The CLI ends either way.
        return

    # Setup Data Loaders
    train_loader, valid_loader, test_loader = _build_loaders(ds_class, ds_path, ds_transform, train_opts)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)
    # Starts from model.which_epoch, so a run resumed from a checkpoint continues its count
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        _train_one_epoch(model, train_loader, error_logger)

        # Validation and Testing Iterations
        for loader, split in zip([valid_loader, test_loader], ['validation', 'test']):
            _evaluate(model, loader, split, epoch, error_logger, visualizer)

        # Update the plots, then clear the accumulated errors for the next epoch
        for split in ['train', 'validation', 'test']:
            visualizer.plot_current_errors(epoch, error_logger.get_errors(split), split_name=split)
            visualizer.print_current_errors(epoch, error_logger.get_errors(split), split_name=split)
        error_logger.reset()

        # Save the model parameters
        if epoch % train_opts.save_epoch_freq == 0:
            model.save(epoch)

        # Update the model learning rate
        model.update_learning_rate()


# Build the train / validation / test DataLoaders for one dataset class.
def _build_loaders(ds_class, ds_path, ds_transform, train_opts):
    """Return ``(train_loader, valid_loader, test_loader)``.

    Only the training split is shuffled. Validation and test both use the
    ``'valid'`` transform. The worker count comes from the optional
    ``num_workers`` config key and falls back to 16, the previous hard-coded
    value.
    """
    preload = train_opts.preloadData
    batch_size = train_opts.batchSize
    num_workers = getattr(train_opts, 'num_workers', 16)

    train_dataset = ds_class(ds_path, split='train', transform=ds_transform['train'], preload_data=preload)
    valid_dataset = ds_class(ds_path, split='validation', transform=ds_transform['valid'], preload_data=preload)
    test_dataset = ds_class(ds_path, split='test', transform=ds_transform['valid'], preload_data=preload)

    train_loader = DataLoader(dataset=train_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader


# Run one optimisation pass over the training loader and log its errors.
def _train_one_epoch(model, train_loader, error_logger):
    """Call ``model.optimize_parameters()`` on every batch and log under ``'train'``."""
    for images, labels in tqdm(train_loader, total=len(train_loader)):
        # Make a training update
        model.set_input(images, labels)
        model.optimize_parameters()
        # Error visualisation
        error_logger.update(model.get_current_errors(), split='train')


# Run model.validate() over one evaluation split, logging errors/stats and showing visuals.
def _evaluate(model, loader, split, epoch, error_logger, visualizer):
    """Evaluate ``model`` on ``loader`` and log results under ``split``.

    Errors and segmentation stats are merged into one dict per batch. On a key
    collision, stats override errors. Predictions are displayed for every batch
    with ``save_result=False``.
    """
    for images, labels in tqdm(loader, total=len(loader)):
        # Make a forward pass with the model
        model.set_input(images, labels)
        model.validate()
        # Error visualisation
        errors = model.get_current_errors()
        stats = model.get_segmentation_stats()
        error_logger.update({**errors, **stats}, split=split)
        # Visualise predictions
        visuals = model.get_current_visuals()
        visualizer.display_current_results(visuals, epoch=epoch, save_result=False)
if __name__ == '__main__':
    import argparse

    # Command-line entry point: read the config path and the debug flag, then train.
    cli = argparse.ArgumentParser(description='CNN Seg Training Function')
    cli.add_argument('-c', '--config',
                     required=True,
                     help='training config file')
    cli.add_argument('-d', '--debug',
                     action='store_true',
                     help='returns number of parameters and bp/fp runtime')
    train(cli.parse_args())