predict_det.py
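# Multi-process inference script: a detection network (DetNet) first localizes the
# garment, the cropped region is passed to a keypoint network (heatmap + PAF), and
# each worker writes its predictions to a temporary CSV that main() merges at the end.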
from __future__ import print_function, division

import os
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'  # set before importing mxnet to disable cuDNN convolution autotuning

import multiprocessing as mp
import argparse

import cv2
import mxnet as mx
import numpy as np
import pandas as pd

from lib.config import cfg
from lib.model import DetNet, load_model, multi_scale_predict, multi_scale_detection
from lib.utils import draw_heatmap, draw_paf, draw_kps, draw_box
from lib.utils import get_logger, crop_patch
from lib.detect_kps import detect_kps
from lib.rpn import AnchorProposal

file_pattern = './result/tmp_%s_result_%d.csv'
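# Expand a detection box by a fraction of its width/height and clip it to the
# image, so the keypoint crop keeps some surrounding context.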
def get_border(bbox, w, h, expand=0.1):
    xmin, ymin, xmax, ymax = bbox
    bh, bw = ymax - ymin, xmax - xmin
    xmin -= expand * bw
    xmax += expand * bw
    ymin -= expand * bh
    ymax += expand * bh
    xmin = max(min(int(xmin), w), 0)
    xmax = max(min(int(xmax), w), 0)
    ymin = max(min(int(ymin), h), 0)
    ymax = max(min(int(ymax), h), 0)
    return (xmin, ymin, xmax, ymax)
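# Per-process worker: loads the detection and keypoint models on its own context,
# runs detection followed by keypoint prediction on its slice of the dataframe,
# and writes the results to a temporary CSV indexed by worker id.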
def work_func(df, idx, args):
    # hyper parameters
    ctx = mx.cpu(0) if args.gpu == -1 else mx.gpu(args.gpu)
    data_dir = args.data_dir
    version = args.version
    show = args.show
    multi_scale = args.multi_scale
    logger = get_logger()
    # model
    feat_stride = cfg.FEAT_STRIDE
    scales = cfg.DET_SCALES
    ratios = cfg.DET_RATIOS
    anchor_proposals = [AnchorProposal(scales[i], ratios, feat_stride[i]) for i in range(2)]
    detnet = DetNet(anchor_proposals)
    creator, featname, fixed = cfg.BACKBONE_Det['resnet50']
    detnet.init_backbone(creator, featname, fixed, pretrained=False)
    detnet.load_params(args.det_model, ctx)
    detnet.hybridize()
    kpsnet = load_model(args.kps_model, version=version)
    kpsnet.collect_params().reset_ctx(ctx)
    kpsnet.hybridize()
    # data
    image_ids = df['image_id'].tolist()
    image_paths = [os.path.join(data_dir, img_id) for img_id in image_ids]
    image_categories = df['image_category'].tolist()
    # run
    result = []
    for i, (path, category) in enumerate(zip(image_paths, image_categories)):
        img = cv2.imread(path)
        # detection
        h, w = img.shape[:2]
        dets = multi_scale_detection(detnet, ctx, img, category)
        if len(dets) != 0:
            bbox = dets[0, :4]
            score = dets[0, -1]
        else:
            # fall back to the full image if nothing is detected
            bbox = [0, 0, w, h]
            score = 0
        bbox = get_border(bbox, w, h, 0.2)
        roi = crop_patch(img, bbox)
        # predict kps
        heatmap, paf = multi_scale_predict(kpsnet, ctx, roi, multi_scale)
        kps_pred = detect_kps(roi, heatmap, paf, category)
        # shift keypoints from ROI coordinates back to full-image coordinates
        x1, y1 = bbox[:2]
        kps_pred[:, 0] += x1
        kps_pred[:, 1] += y1
        result.append(kps_pred)
        # show
        if show:
            landmark_idx = cfg.LANDMARK_IDX[category]
            heatmap = heatmap[landmark_idx].max(axis=0)
            cv2.imshow('det', draw_box(img, bbox, '%s_%.2f' % (category, score)))
            cv2.imshow('heatmap', draw_heatmap(roi, heatmap))
            cv2.imshow('kps_pred', draw_kps(img, kps_pred))
            cv2.imshow('paf', draw_paf(roi, paf))
            key = cv2.waitKey(0)
            if key == 27:  # Esc quits early
                break
        if i % 100 == 0:
            logger.info('Worker %d processed %d samples', idx, i + 1)
    # save
    fn = file_pattern % (args.type, idx)
    with open(fn, 'w') as fout:
        header = 'image_id,image_category,neckline_left,neckline_right,center_front,shoulder_left,shoulder_right,armpit_left,armpit_right,waistline_left,waistline_right,cuff_left_in,cuff_left_out,cuff_right_in,cuff_right_out,top_hem_left,top_hem_right,waistband_left,waistband_right,hemline_left,hemline_right,crotch,bottom_left_in,bottom_left_out,bottom_right_in,bottom_right_out\n'
        fout.write(header)
        for img_id, category, kps in zip(image_ids, image_categories, result):
            fout.write(img_id)
            fout.write(',%s' % category)
            for p in kps:
                s = ',%d_%d_%d' % (p[0], p[1], p[2])
                fout.write(s)
            fout.write('\n')
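# Parse arguments, split the val/test dataframe across workers, run each worker
# in its own process, then merge the temporary CSVs into the final result file.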
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--det-model', type=str, required=True)
    parser.add_argument('--kps-model', type=str, required=True)
    parser.add_argument('--version', type=int, default=2)
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--multi-scale', action='store_true')
    parser.add_argument('--num-worker', type=int, default=1)
    parser.add_argument('--type', type=str, default='val', choices=['val', 'test'])
    args = parser.parse_args()
    print(args)
    # data
    if args.type == 'val':
        data_dir = cfg.DATA_DIR
        df = pd.read_csv(os.path.join(data_dir, 'val.csv'))
    else:
        data_dir = os.path.join(cfg.DATA_DIR, 'r2-test-a')
        df = pd.read_csv(os.path.join(data_dir, 'test.csv'))
    args.data_dir = data_dir
    # df = df.sample(frac=1)
    num_worker = args.num_worker
    num_sample = len(df) // num_worker + 1
    dfs = [df[i * num_sample: (i + 1) * num_sample] for i in range(num_worker)]
    # run
    workers = [mp.Process(target=work_func, args=(dfs[i], i, args)) for i in range(num_worker)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    # merge
    result = pd.concat([pd.read_csv(file_pattern % (args.type, i)) for i in range(num_worker)])
    result.to_csv('./result/%s_result.csv' % args.type, index=False)


if __name__ == '__main__':
    main()