-
Notifications
You must be signed in to change notification settings - Fork 9
/
recognize.py
132 lines (105 loc) · 3.47 KB
/
recognize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import tensorflow as tf
import utils
import matching
from models import detection, description
# Parsed command-line arguments (an argparse Namespace). Stays None until the
# ``__main__`` block at the bottom of this script assigns the result of
# ``parse_args()``; detect_pores, describe_detections and main read it as a
# module-level global, so they only work after that assignment.
FLAGS = None
def detect_pores(imgs):
    """Run the pore detection network over every image in ``imgs``.

    A fresh TensorFlow graph is built for the detection net, its weights are
    restored from ``FLAGS.det_model_dir``, and ``utils.detect_pores`` is
    called once per image with half of ``FLAGS.det_patch_size``,
    ``FLAGS.det_prob_thr`` and ``FLAGS.nms_inter_thr``.

    Args:
      imgs: sequence of images (as produced by ``utils.load_image``).

    Returns:
      A list holding one ``utils.detect_pores`` result per image, in the same
      order as ``imgs``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # input placeholder shared by every per-image detection call
        image_placeholder, _ = utils.placeholder_inputs()

        print('Building detection net graph...')
        net = detection.Net(image_placeholder, training=False)
        print('Done')

        with tf.Session() as session:
            print('Restoring detection model in {}...'.format(
                FLAGS.det_model_dir))
            utils.restore_model(session, FLAGS.det_model_dir)
            print('Done')

            detections = []
            for image in imgs:
                detections.append(
                    utils.detect_pores(image, image_placeholder,
                                       net.predictions,
                                       FLAGS.det_patch_size // 2,
                                       FLAGS.det_prob_thr,
                                       FLAGS.nms_inter_thr, session))

    return detections
def describe_detections(imgs, dets):
    """Compute descriptors for detected pores with the description network.

    A fresh TensorFlow graph is built for the description net and its weights
    are restored from ``FLAGS.desc_model_dir``. ``utils.trained_descriptors``
    is then called for each (image, detections) pair using
    ``FLAGS.desc_patch_size``.

    Args:
      imgs: sequence of images.
      dets: sequence of per-image detections, parallel to ``imgs`` (e.g. the
        output of ``detect_pores``). Iteration stops at the shorter of the
        two sequences.

    Returns:
      A ``(descs, new_dets)`` tuple of lists, one entry per processed image,
      holding the two values ``utils.trained_descriptors`` returns for that
      image. ``new_dets`` may differ from ``dets`` -- presumably detections
      whose patch cannot be described are dropped; TODO confirm in
      ``utils.trained_descriptors``.
    """
    with tf.Graph().as_default():
        # input placeholder shared by every per-image description call
        image_placeholder, _ = utils.placeholder_inputs()

        print('Building description net graph...')
        net = description.Net(image_placeholder, training=False)
        print('Done')

        with tf.Session() as session:
            print('Restoring description model in {}...'.format(
                FLAGS.desc_model_dir))
            utils.restore_model(session, FLAGS.desc_model_dir)
            print('Done')

            descs, new_dets = [], []
            for image, image_dets in zip(imgs, dets):
                image_descs, updated_dets = utils.trained_descriptors(
                    image, image_dets, FLAGS.desc_patch_size, session,
                    image_placeholder, net.descriptors)
                descs.append(image_descs)
                new_dets.append(updated_dets)

    return descs, new_dets
def main(match_thr=0.7):
    """Decide whether the two images in ``FLAGS.img_paths`` are a match.

    Pores are detected in both images, described, and the pair is scored
    with ``matching.basic``. The score is printed, followed by
    'genuine pair' when it exceeds ``FLAGS.score_thr`` and
    'impostor pair' otherwise.

    Args:
      match_thr: value forwarded as ``thr`` to ``matching.basic``
        (presumably a descriptor-matching ratio threshold -- confirm in
        ``matching.basic``). Defaults to 0.7, the value previously
        hard-coded here, so ``main()`` behaves exactly as before.
    """
    # load images
    imgs = [utils.load_image(path) for path in FLAGS.img_paths]

    dets = detect_pores(imgs)
    # Both stages already build their own tf.Graph; the reset is kept from the
    # original flow as extra cleanup between them.
    tf.reset_default_graph()
    # Only the descriptors are needed for matching; the detections returned
    # alongside them are not used further.
    descs, _ = describe_detections(imgs, dets)

    score = matching.basic(descs[0], descs[1], thr=match_thr)
    print('similarity score = {}'.format(score))
    if score > FLAGS.score_thr:
        print('genuine pair')
    else:
        print('impostor pair')
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--img_paths',
required=True,
type=str,
nargs=2,
help='path to images to be recognized')
parser.add_argument(
'--det_model_dir',
required=True,
type=str,
help='path to pore detection trained model')
parser.add_argument(
'--desc_model_dir',
required=True,
type=str,
help='path to pore description trained model')
parser.add_argument(
'--score_thr',
default=2,
type=int,
help='score threshold to determine if pair is genuine or impostor')
parser.add_argument(
'--det_patch_size', default=17, type=int, help='detection patch size')
parser.add_argument(
'--det_prob_thr',
default=0.9,
type=float,
help='probability threshold for discarding detections')
parser.add_argument(
'--nms_inter_thr',
default=0.1,
type=float,
help='NMS area intersection threshold')
parser.add_argument(
'--desc_patch_size',
default=32,
type=int,
help='patch size around each detected keypoint to describe')
FLAGS = parser.parse_args()
main()