-
Notifications
You must be signed in to change notification settings - Fork 1
/
eval_cuhk03.py
87 lines (73 loc) · 3.18 KB
/
eval_cuhk03.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from collections import defaultdict
import numpy as np
import scipy.io
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100, seed=None):
    """Evaluate a re-ID ranking with the CUHK03 (single-gallery-shot) protocol.

    For each query, gallery entries that share both the query's person ID and
    camera ID are discarded. Then, ``N`` times, one remaining image is drawn at
    random for every gallery identity, and CMC and AP are computed on that
    reduced ranking. Per-query scores are averaged over the ``N`` draws and
    then over every query whose identity still appears in the gallery; queries
    with no valid match are skipped.

    Args:
        distmat: 2-D array of shape (num_query, num_gallery); smaller values
            rank first.
        q_pids: 1-D array of query person IDs (length num_query).
        g_pids: 1-D array of gallery person IDs (length num_gallery).
        q_camids: 1-D array of query camera IDs.
        g_camids: 1-D array of gallery camera IDs.
        max_rank: maximum length of the returned CMC curve; clipped to
            num_gallery. The curve is shorter when the gallery has fewer
            distinct identities than this.
        N: number of random sampling repeats per query.
        seed: optional integer seed for the sampling RNG. ``None`` (default)
            samples from NumPy's global random state, exactly as before.

    Returns:
        ``(cmc, mAP)``: ``cmc`` is a float32 array whose entry ``k`` is the
        fraction of valid queries with a correct match in the top ``k + 1``
        results; ``mAP`` is the mean of the per-query average precisions.

    Raises:
        AssertionError: if no query identity appears in the filtered gallery.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    # The np.random module exposes the global-state ``choice``; using it when no
    # seed is given keeps the default draw sequence identical to np.random.choice.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Gallery indices sorted by ascending distance, and a 0/1 matrix marking
    # which ranked gallery entries share the query's person ID.
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of queries with at least one valid match
    for q_idx in range(num_q):
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]
        order = indices[q_idx]
        # Same person seen by the same camera is not a cross-camera match.
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)
        orig_cmc = matches[q_idx][keep]  # 1 where the ranked entry is a correct match
        if not np.any(orig_cmc):
            # the query identity does not appear in the filtered gallery
            continue

        # Group positions (within the filtered ranking) by gallery identity.
        kept_g_pids = g_pids[order][keep]
        g_pids_dict = defaultdict(list)
        for idx, pid in enumerate(kept_g_pids):
            g_pids_dict[pid].append(idx)
        id_groups = list(g_pids_dict.values())
        # 1-based positions in a sampled ranking (exactly one entry per identity).
        rank_pos = np.arange(1, len(id_groups) + 1, dtype=np.float64)

        cmc, AP = 0., 0.
        for _ in range(N):
            # builtin bool: the np.bool alias was removed in NumPy 1.24
            mask = np.zeros(len(orig_cmc), dtype=bool)
            for idxs in id_groups:
                # randomly sample one image for each gallery identity
                mask[rng.choice(idxs)] = True
            masked_orig_cmc = orig_cmc[mask]
            hits = masked_orig_cmc.cumsum()
            _cmc = np.minimum(hits, 1)  # 1 from the first correct match onward
            cmc += _cmc[:max_rank].astype(np.float32)
            # AP: precision at each correct-match position, averaged over matches.
            num_rel = masked_orig_cmc.sum()
            precision = hits / rank_pos
            AP += (precision * masked_orig_cmc).sum() / num_rel
        cmc /= N
        AP /= N
        all_cmc.append(cmc)
        all_AP.append(AP)
        num_valid_q += 1.

    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    if num_valid_q == 0:
        raise AssertionError("Error: all query identities do not appear in gallery")
    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)
    return all_cmc, mAP
if __name__ == '__main__':
    # Load precomputed query/gallery distances and labels stored in MATLAB format.
    result = scipy.io.loadmat('./result_cuhk03.mat')
    distmat = result['distmat']
    # loadmat stores vectors as 2-D arrays; ravel (unlike squeeze) always yields
    # a 1-D array, even when there is only a single query or gallery entry.
    q_pids = np.ravel(result['q_pids'])
    g_pids = np.ravel(result['g_pids'])
    q_camids = np.ravel(result['q_camids'])
    g_camids = np.ravel(result['g_camids'])
    cmc, mAP = eval_cuhk03(distmat, q_pids=q_pids, g_pids=g_pids,
                           q_camids=q_camids, g_camids=g_camids, max_rank=len(g_pids))
    # Report only the ranks the CMC curve covers: it has one entry per gallery
    # identity, so it is shorter than 20 for small galleries.
    report = ' '.join('Rank@%d:%f' % (k, cmc[k - 1]) for k in (1, 5, 10, 20) if k <= len(cmc))
    print('%s mAP:%f' % (report, mAP))
    # To persist the CMC curve: scipy.io.savemat('resnet.mat', {'CMC': cmc})