-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
120 lines (106 loc) · 4.53 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#-*- coding:utf-8 -*-
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import warnings
warnings.filterwarnings('ignore')
import torch
import pickle
import argparse
from utils.timer import Timer
import torch.backends.cudnn as cudnn
from layers.functions import Detect, PriorBox
from data import BaseTransform
from configs.CC import Config
from dsod import build_net
from tqdm import tqdm
from utils.core import *
parser = argparse.ArgumentParser(description='STDN Evaluation')
parser.add_argument(
'-c', '--config', default='configs/DSOD300-64-192-48-1.py', type=str)
parser.add_argument('-d', '--dataset', default='COCO',
help='VOC or COCO version')
parser.add_argument('-m', '--trained_model', default=None,
type=str, help='Trained state_dict file path to open')
parser.add_argument('--test', action='store_true',
help='to submit a test file')
args = parser.parse_args()
print_info('----------------------------------------------------------------------\n'
'| STDN Evaluation Program |\n'
'----------------------------------------------------------------------', ['yellow', 'bold'])
global cfg
cfg = Config.fromfile(args.config)
cfg.model.num_classes = 21 if args.dataset == 'VOC' else 81
if not os.path.exists(cfg.test_cfg.save_folder):
os.mkdir(cfg.test_cfg.save_folder)
anchor_config = anchors(cfg.model)
print_info('The Anchor info: \n{}'.format(anchor_config))
priorbox = PriorBox(anchor_config)
with torch.no_grad():
priors = priorbox.forward()
if cfg.test_cfg.cuda:
priors = priors.cuda()
num_classes = cfg.model.num_classes
def test_net(save_folder, net, detector, cuda, testset, transform, max_per_image=300, thresh=0.005):
if not os.path.exists(save_folder):
os.mkdir(save_folder)
num_images = len(testset)
print_info('=> Total {} images to test.'.format(
num_images), ['yellow', 'bold'])
all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
_t = {'im_detect': Timer(), 'misc': Timer()}
det_file = os.path.join(save_folder, 'detections.pkl')
tot_detect_time, tot_nms_time = 0, 0
print_info('Begin to evaluate', ['yellow', 'bold'])
for i in tqdm(range(num_images)):
img = testset.pull_image(i)
# step1: CNN detection
_t['im_detect'].tic()
boxes, scores = image_forward(
img, net, cuda, priors, detector, transform)
detect_time = _t['im_detect'].toc()
# step2: Post-process: NMS
_t['misc'].tic()
nms_process(num_classes, i, scores, boxes, cfg,
thresh, all_boxes, max_per_image)
nms_time = _t['misc'].toc()
tot_detect_time += detect_time if i > 0 else 0
tot_nms_time += nms_time if i > 0 else 0
with open(det_file, 'wb') as f:
pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
print_info('===> Evaluating detections', ['yellow', 'bold'])
testset.evaluate_detections(all_boxes, save_folder)
print_info('Detect time per image: {:.3f}s'.format(
tot_detect_time / (num_images - 1)))
print_info('Nms time per image: {:.3f}s'.format(
tot_nms_time / (num_images - 1)))
print_info('Total time per image: {:.3f}s'.format(
(tot_detect_time + tot_nms_time) / (num_images - 1)))
print_info('FPS: {:.3f} fps'.format(
(num_images - 1) / (tot_detect_time + tot_nms_time)))
if __name__ == '__main__':
net = build_net('test', cfg.model.input_size, cfg.model)
init_net(net, cfg, args.trained_model)
print_info('===> Finished constructing and loading model',
['yellow', 'bold'])
net.eval()
_set = 'eval_sets' if not args.test else 'test_sets'
testset = get_dataloader(cfg, args.dataset, _set)
if cfg.test_cfg.cuda:
net = net.cuda()
cudnn.benckmark = True
else:
net = net.cpu()
detector = Detect(num_classes, cfg.loss.bkg_label, anchor_config)
save_folder = os.path.join(cfg.test_cfg.save_folder, args.dataset)
_preprocess = BaseTransform(
cfg.model.input_size, cfg.model.rgb_means, (2, 0, 1))
test_net(save_folder,
net,
detector,
cfg.test_cfg.cuda,
testset,
transform=_preprocess,
max_per_image=cfg.test_cfg.topk,
thresh=cfg.test_cfg.score_threshold)