-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_CULane.py
112 lines (92 loc) · 3.52 KB
/
test_CULane.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import json
import os
import subprocess

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

import dataset
from config import *
from model import SCNN
from utils.prob2lines import getLane
from utils.transforms import *
def parse_args(argv=None):
    """Parse the command-line options of this test script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is how the
            script itself calls this function.

    Returns:
        An ``argparse.Namespace`` with a single attribute, ``exp_dir``
        (default ``"./experiments/exp10"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="./experiments/exp10",
        help="experiment directory containing cfg.json and the checkpoint",
    )
    return parser.parse_args(argv)
# ------------ config ------------
args = parse_args()
exp_dir = args.exp_dir
# normpath drops a trailing separator, so "experiments/exp10/" still yields
# "exp10" instead of an empty name; basename also honours the OS separator.
exp_name = os.path.basename(os.path.normpath(exp_dir))

# Experiment settings are read from <exp_dir>/cfg.json.
with open(os.path.join(exp_dir, "cfg.json")) as f:
    exp_cfg = json.load(f)

resize_shape = tuple(exp_cfg['dataset']['resize_shape'])

# Use the GPU when one is available; otherwise fall back to the CPU instead
# of failing on the first .to(device) call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def split_path(path):
    """Split a path into the list of its components, outermost first.

    The path is peeled from the right with ``os.path.split``. A root
    component (for example ``"/"``) is kept as the first element. Trailing
    separators are ignored, so ``"a/b/"`` and ``"a/b"`` give the same result.

    Args:
        path: The path string to split.

    Returns:
        A list of path components; an empty list for an empty path.
    """
    parts = []
    while True:
        head, tail = os.path.split(path)
        if tail:
            parts.append(tail)
        elif head == path:
            # os.path.split made no progress: we are at the root or at "".
            if head:
                parts.append(head)
            break
        # Either a component was consumed or a trailing separator was
        # stripped; continue with what is left.
        path = head
    # Components were collected innermost-first.
    parts.reverse()
    return parts
# ------------ data and model ------------
# # CULane mean, std
# mean=(0.3598, 0.3653, 0.3662)
# std=(0.2573, 0.2663, 0.2756)
# Imagenet mean, std
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# NOTE: pop() also removes 'dataset_name' from exp_cfg['dataset'].
dataset_name = exp_cfg['dataset'].pop('dataset_name')
# The dataset class is looked up by name in the `dataset` module.
Dataset_Type = getattr(dataset, dataset_name)
transform = Compose(Resize(resize_shape), ToTensor(),
                    Normalize(mean=mean, std=std))
test_dataset = Dataset_Type(Dataset_Path[dataset_name], "test", transform)
test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=test_dataset.collate, num_workers=4)

net = SCNN(resize_shape, pretrained=False)
# Checkpoint is expected at <exp_dir>/<exp_name>_best.pth; reuse exp_name
# rather than re-deriving it from exp_dir a second time.
save_name = os.path.join(exp_dir, exp_name + '_best.pth')
# Load onto the CPU first; the model is moved to `device` below.
save_dict = torch.load(save_name, map_location='cpu')
print("\nloading", save_name, "...... From Epoch: ", save_dict['epoch'])
net.load_state_dict(save_dict['net'])
net = torch.nn.DataParallel(net.to(device))
net.eval()
# ------------ test ------------
out_path = os.path.join(exp_dir, "coord_output")
evaluation_path = os.path.join(exp_dir, "evaluate")
# exist_ok avoids the separate existence check before creating each directory.
os.makedirs(out_path, exist_ok=True)
os.makedirs(evaluation_path, exist_ok=True)

with torch.no_grad():
    # tqdm takes its total from len(test_loader), i.e. one tick per batch.
    for sample in tqdm(test_loader):
        img = sample['img'].to(device)
        img_name = sample['img_name']

        # Only the first two outputs of the network are used here.
        seg_pred, exist_pred = net(img)[:2]
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()
        exist_pred = exist_pred.detach().cpu().numpy()

        for seg, exist_scores, name in zip(seg_pred, exist_pred, img_name):
            # Threshold the first four existence scores at 0.5.
            exist = [1 if exist_scores[i] > 0.5 else 0 for i in range(4)]
            lane_coords = getLane.prob2lines_CULane(seg, exist, resize_shape=(590, 1640), y_px_gap=20, pts=18)

            # Mirror the image's last two parent directories under out_path.
            path_tree = split_path(name)
            lines_dir = os.path.join(out_path, *path_tree[-3:-1])
            # "<stem>.lines.txt"; splitext works for any extension length.
            lines_name = os.path.splitext(path_tree[-1])[0] + ".lines.txt"
            os.makedirs(lines_dir, exist_ok=True)

            # One text line per lane: "x y " pairs followed by a newline.
            with open(os.path.join(lines_dir, lines_name), "w") as f:
                for lane in lane_coords:
                    for (x, y) in lane:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)
# ---- evaluate ----
# Run the evaluation script with an argument list (no shell string), so
# exp_name is passed to Run.sh as one literal argument even if it contains
# spaces or shell metacharacters. The exit status is ignored, as before.
subprocess.run(["sh", "utils/lane_evaluation/CULane/Run.sh", exp_name], check=False)