"""
The Main module
"""

import sys
import argparse
import logging
from pathlib import Path

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR
from torch.utils.data import DataLoader

import utils.utils as utils
from models.lidcamnet_fcn import LidCamNetEarlyFusion, LidCamNetLateFusion, LidCamNet
from data_processing.road_dataset import (
    a2d2_dataset,
    a2d2_dataset_no_lidar,
    a2d2_ip_input_file,
    a2d2_output_file
)
from data_processing.data_processing import crossval_split_a2d2
from misc.losses import BCEJaccardLoss
from misc.polylr_scheduler import PolyLR
from misc.transforms import (
    train_transformations_a2d2,
    transform_normalize_img,
    valid_transformations_a2d2,
    transform_normalize_lidar
)

# For reproducibility
# torch.manual_seed(111)
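# Note: full reproducibility on GPU would also require torch.cuda.manual_seed_all,
# cudnn.deterministic = True, and cudnn.benchmark = False.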


def main(*args, **kwargs):
    parser = argparse.ArgumentParser(description="Argument parser for the main module, which implements the training procedure.")
    parser.add_argument("--root-dir", type=str, required=True, help="Path to the root dir where models will be stored.")
    parser.add_argument("--dataset-path", type=str, required=True, help="Path to the A2D2 dataset dir containing the pickle files.")
    parser.add_argument("--fold", type=int, default=1, help="Number of the validation fold.")

    # Optimizer options
    parser.add_argument("--optim", type=str, default="SGD", help="Type of optimizer: SGD or Adam.")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for the optimizer.")
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum for the SGD optimizer.")

    # Scheduler options
    parser.add_argument("--scheduler", type=str, default="poly", help="Type of LR scheduler: 'step', 'multi-step', 'rlr-plat' or 'poly'.")
    parser.add_argument("--step-st", type=int, default=5, help="Step size for the StepLR schedule.")
    parser.add_argument("--milestones", type=str, default="30,70,90", help="Comma-separated milestones for the MultiStepLR schedule.")
    parser.add_argument("--gamma", type=float, default=0.1, help="Gamma parameter for the StepLR and MultiStepLR schedules.")
    parser.add_argument("--patience", type=int, default=5, help="Patience parameter for the ReduceLROnPlateau schedule.")

    # Model params
    parser.add_argument("--model-type", type=str, default="lcn_early", help="Type of model: 'lcn_early' (early fusion), 'lcn_late' (late fusion) or 'lcn' (camera only).")
    parser.add_argument("--init-type", type=str, default="He", help="Initialization type. Can be 'He' or 'Xavier'.")
    parser.add_argument("--act-type", type=str, default="relu", help="Activation type. Can be ReLU, CELU or FTSwish+.")
    parser.add_argument("--enc-bn-enable", type=int, default=1, help="Enable batch normalization in the encoder module.")
    parser.add_argument("--dec-bn-enable", type=int, default=1, help="Enable batch normalization in the decoder module.")
    parser.add_argument("--skip-conn", type=int, default=0, help="Enable the skip-connection in the context module.")

    # Other options
    parser.add_argument("--n-epochs", type=int, default=100, help="Number of training epochs.")
    parser.add_argument("--batch-size", type=int, default=2, help="Number of examples per batch.")
    parser.add_argument("--batch-factor", type=int, default=8, help="Batch factor passed through to the training routine.")
    parser.add_argument("--num-workers", type=int, default=8, help="Number of data-loading workers.")
    parser.add_argument("--device-ids", type=str, default="0", help="Comma-separated device IDs for multi-GPU training.")
    parser.add_argument("--alpha", type=float, default=0, help="Modulation factor for the custom loss.")
    parser.add_argument("--status-every", type=int, default=1, help="How often (in epochs) to report training status.")
    args = parser.parse_args()

    # Console logger definition
    console_logger = logging.getLogger("console-logger")
    console_logger.setLevel(logging.INFO)
    ch = logging.StreamHandler(stream=sys.stdout)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    console_logger.addHandler(ch)

    console_logger.info(args)

    # Number of classes (binary road segmentation)
    num_classes = 1

    # Model definition
    if args.model_type == "lcn_early":
        model = LidCamNetEarlyFusion(num_classes=num_classes, bn_enable=True)
        console_logger.info("Using LidCamNet-early as the model.")
    elif args.model_type == "lcn_late":
        model = LidCamNetLateFusion(num_classes=num_classes, bn_enable=True)
        console_logger.info("Using LidCamNet-late as the model.")
    elif args.model_type == "lcn":
        model = LidCamNet(num_classes=num_classes, bn_enable=True)
        console_logger.info("Using LidCamNet as the model.")
    else:
        raise ValueError("Unknown model type: {}".format(args.model_type))

    console_logger.info("Number of trainable parameters: {}".format(utils.count_params(model)[1]))

    # Move the model to the available devices
    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()
        cudnn.benchmark = True
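    # cudnn.benchmark lets cuDNN auto-tune convolution algorithms; it pays off
    # when input sizes stay fixed across batches.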

    # Loss definition
    # TODO: Find a reason for using this loss.
    loss = BCEJaccardLoss(alpha=args.alpha)
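    # BCEJaccardLoss (misc/losses.py) presumably blends binary cross-entropy with
    # a Jaccard (IoU) term, with alpha modulating the balance between the two.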

    dataset_path = Path(args.dataset_path)
    images = utils.read_pickle_file(str(dataset_path / a2d2_ip_input_file))
    masks = utils.read_pickle_file(str(dataset_path / a2d2_output_file))

    # Use a data subset
    # images = images[:20]
    # masks = masks[:20]
    images = images[:-round(len(images) / 4)]
    masks = masks[:-round(len(masks) / 4)]
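    # The slicing above drops the last quarter of the samples; remove it to train
    # on the full dataset.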

    # Train/val split for cross-validation by fold
    ((train_imgs, train_masks),
     (valid_imgs, valid_masks)) = crossval_split_a2d2(imgs_paths=images, masks_paths=masks, fold=args.fold)

    # Define the training/validation datasets
    if args.model_type == "lcn":
        train_dataset = a2d2_dataset_no_lidar(img_paths=train_imgs, mask_paths=train_masks,
                                              transform_image=train_transformations_a2d2(),
                                              normalize_image=transform_normalize_img())
        valid_dataset = a2d2_dataset_no_lidar(img_paths=valid_imgs, mask_paths=valid_masks,
                                              transform_image=valid_transformations_a2d2(),
                                              normalize_image=transform_normalize_img())
    else:
        train_dataset = a2d2_dataset(img_paths=train_imgs, mask_paths=train_masks,
                                     transform_image=train_transformations_a2d2(),
                                     normalize_image=transform_normalize_img(),
                                     normalize_lidar=transform_normalize_lidar())
        valid_dataset = a2d2_dataset(img_paths=valid_imgs, mask_paths=valid_masks,
                                     transform_image=valid_transformations_a2d2(),
                                     normalize_image=transform_normalize_img(),
                                     normalize_lidar=transform_normalize_lidar())

    # valid_fmeasure_dataset = a2d2_dataset(img_paths=valid_imgs, mask_paths=valid_masks)

    # Create data loaders
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=torch.cuda.is_available())
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=1,
                              num_workers=args.num_workers, pin_memory=torch.cuda.is_available())
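    # pin_memory speeds up host-to-GPU copies when CUDA is available; validation
    # runs with batch_size=1, presumably so per-image metrics are easy to aggregate.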
console_logger.info("Train dataset length: {}".format(len(train_dataset)))
console_logger.info("Validation dataset length: {}".format(len(valid_dataset)))

    # Optimizer definition
    if args.optim == "SGD":
        optim = SGD(params=model.parameters(), lr=args.lr, momentum=args.momentum)
        console_logger.info("Using the SGD optimizer with initial lr={0} and momentum={1}.".format(args.lr, args.momentum))
    else:
        optim = Adam(params=model.parameters(), lr=args.lr)
        console_logger.info("Using the Adam optimizer with initial lr={0}.".format(args.lr))

    # Scheduler definition
    if args.scheduler == "step":
        lr_scheduler = StepLR(optimizer=optim, step_size=args.step_st, gamma=args.gamma)
        console_logger.info("Using the StepLR scheduler with step={} and gamma={}.".format(args.step_st, args.gamma))
    elif args.scheduler == "multi-step":
        lr_scheduler = MultiStepLR(optimizer=optim, milestones=[int(m) for m in args.milestones.split(",")], gamma=args.gamma)
        console_logger.info("Using the MultiStepLR scheduler with milestones=[{}] and gamma={}.".format(args.milestones, args.gamma))
    elif args.scheduler == "rlr-plat":
        lr_scheduler = ReduceLROnPlateau(optimizer=optim, patience=args.patience, verbose=True)
        console_logger.info("Using the ReduceLROnPlateau scheduler.")
    elif args.scheduler == "poly":
        lr_scheduler = PolyLR(optimizer=optim, num_epochs=args.n_epochs, alpha=args.gamma)
        console_logger.info("Using the PolyLR scheduler.")
    else:
        raise ValueError("Unknown scheduler type: {}".format(args.scheduler))

    if args.model_type != "lcn":
        valid = utils.binary_validation_routine_a2d2
        utils.train_routine_a2d2(
            args=args,
            console_logger=console_logger,
            root=args.root_dir,
            model=model,
            criterion=loss,
            optimizer=optim,
            scheduler=lr_scheduler,
            train_loader=train_loader,
            valid_loader=valid_loader,
            fm_eval_dataset=None,
            validation=valid,
            fold=args.fold,
            num_classes=num_classes,
            n_epochs=args.n_epochs,
            status_every=args.status_every
        )
    else:
        valid = utils.binary_validation_routine_a2d2_no_lidar
        utils.train_routine_a2d2_no_lidar(
            args=args,
            console_logger=console_logger,
            root=args.root_dir,
            model=model,
            criterion=loss,
            optimizer=optim,
            scheduler=lr_scheduler,
            train_loader=train_loader,
            valid_loader=valid_loader,
            fm_eval_dataset=None,
            validation=valid,
            fold=args.fold,
            num_classes=num_classes,
            n_epochs=args.n_epochs,
            status_every=args.status_every
        )
if __name__ == "__main__":
main()