-
Notifications
You must be signed in to change notification settings - Fork 94
/
resa50_culane.py
116 lines (99 loc) · 2.12 KB
/
resa50_culane.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Top-level model type; its components are configured by the dicts below.
net = {'type': 'Detector'}

# Feature extractor built from a ResNet-50.
# NOTE(review): replace_stride_with_dilation is presumably forwarded to
# torchvision's resnet50, in which case layer3/layer4 keep stride 1 and use
# dilation instead -- consistent with featuremap_out_stride = 8. Confirm in
# ResNetWrapper.
backbone = {
    'type': 'ResNetWrapper',
    'resnet': 'resnet50',
    'pretrained': True,
    'replace_stride_with_dilation': [False, True, True],
    'out_conv': True,
    'in_channels': [64, 128, 256, -1],
}
# Channel count and downsampling factor of the feature map handed from the
# backbone to the aggregator.
featuremap_out_channel = 128
featuremap_out_stride = 8

# Rows at which lanes are sampled: 589, 569, ..., 249 (every 20 px, moving
# upwards, stopping before 230). Presumably y coordinates in the original
# 590-px-high frame, whose last row is 589.
sample_y = range(589, 230, -20)

# RESA aggregator. The direction codes are presumably down/up/right/left.
aggregator = {
    'type': 'RESA',
    'direction': ['d', 'u', 'r', 'l'],
    'alpha': 2.0,
    'iter': 4,
    'conv_stride': 9,
}

# Segmentation head with a plain decoder and a lane-existence branch.
# thr is presumably the score threshold used when turning the prediction into
# lanes sampled at the rows in sample_y.
heads = {
    'type': 'LaneSeg',
    'decoder': {'type': 'PlainDecoder'},
    'exist': {'type': 'ExistHead'},
    'thr': 0.3,
    'sample_y': sample_y,
}
# Training-loop implementation and evaluation protocol.
trainer = {'type': 'RESA'}
evaluator = {'type': 'CULane'}

# SGD with momentum and weight decay (presumably torch.optim.SGD). lr is the
# base learning rate that the LambdaLR scheduler's factor is applied to.
optimizer = {
    'type': 'SGD',
    'lr': 0.030,
    'weight_decay': 1e-4,
    'momentum': 0.9,
}
import math

# Training length. total_iter is the number of steps the poly schedule decays
# over; 88880 is presumably the size of the CULane 'train' split (TODO confirm
# against the dataset list file if it changes).
epochs = 12
batch_size = 8
total_iter = (88880 // batch_size) * epochs

# "Poly" learning-rate factor: 1 at iteration 0, falling to 0 at total_iter
# with exponent 0.9 (presumably consumed as torch LambdaLR's multiplicative
# factor, stepped per iteration since lr_update_by_epoch is False).
# The base is clamped at 0: total_iter uses floor division, so when the loader
# yields a final partial batch (batch_size not dividing 88880) _iter can run
# past total_iter, and math.pow of a negative base with a fractional exponent
# raises ValueError. Past the end the factor simply stays at 0.
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(max(0.0, 1 - _iter / total_iter), 0.9),
)

# Loss weighting and evaluation/checkpoint cadence. The *_ep values are
# presumably measured in epochs (save_ep = epochs => save at the end only).
seg_loss_weight = 1.0
eval_ep = 6
save_ep = epochs
bg_weight = 0.4
# Per-channel normalisation constants, presumably applied as
# (img - mean) / std; a std of 1 means mean subtraction only.
# NOTE(review): these look like the Caffe-style BGR ImageNet means -- confirm
# the channel order of the loaded images.
img_norm = {
    'mean': [103.939, 116.779, 123.68],
    'std': [1., 1., 1.],
}

# Network input resolution; the Resize steps below take (width, height).
img_height = 288
img_width = 800
# Presumably the number of rows cropped from the top of each source frame.
cut_height = 240
# Source frame size before any cropping/resizing (presumably).
ori_img_h = 590
ori_img_w = 1640

# Training pipeline: small random rotation and horizontal flip, then the same
# resize/normalise steps as evaluation; image, mask and lane-existence labels
# are converted to tensors.
train_process = [
    {'type': 'RandomRotation', 'degree': (-2, 2)},
    {'type': 'RandomHorizontalFlip'},
    {'type': 'Resize', 'size': (img_width, img_height)},
    {'type': 'Normalize', 'img_norm': img_norm},
    {'type': 'ToTensor', 'keys': ['img', 'mask', 'lane_exist']},
]

# Evaluation pipeline: no random steps, and only the image becomes a tensor.
val_process = [
    {'type': 'Resize', 'size': (img_width, img_height)},
    {'type': 'Normalize', 'img_norm': img_norm},
    {'type': 'ToTensor', 'keys': ['img']},
]
dataset_path = './data/CULane'

# All three stages read CULane from the same root. Both 'val' and 'test'
# evaluate on the 'test' split with val_process (which has no Random* steps);
# only 'train' uses the augmenting train_process.
dataset = {
    'train': {
        'type': 'CULane',
        'data_root': dataset_path,
        'split': 'train',
        'processes': train_process,
    },
    'val': {
        'type': 'CULane',
        'data_root': dataset_path,
        'split': 'test',
        'processes': val_process,
    },
    'test': {
        'type': 'CULane',
        'data_root': dataset_path,
        'split': 'test',
        'processes': val_process,
    },
}
workers = 12  # presumably the number of data-loader worker processes
num_classes = 1 + 4  # presumably background + 4 lane slots
ignore_label = 255  # mask value presumably excluded from the segmentation loss
log_interval = 1000  # logging period, presumably in iterations
lr_update_by_epoch = False  # presumably: step the LR scheduler per iteration, not per epoch