forked from pierluigiferrari/ssd_keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
matching_utils.py
116 lines (92 loc) · 5.38 KB
/
matching_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
'''
Utilities to match ground truth boxes to anchor boxes.
Copyright (C) 2018 Pierluigi Ferrari
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
from __future__ import division
import numpy as np
def match_bipartite_greedy(weight_matrix):
'''
Returns a bipartite matching according to the given weight matrix.
The algorithm works as follows:
Let the first axis of `weight_matrix` represent ground truth boxes
and the second axis anchor boxes.
The ground truth box that has the greatest similarity with any
anchor box will be matched first, then out of the remaining ground
truth boxes, the ground truth box that has the greatest similarity
with any of the remaining anchor boxes will be matched second, and
so on. That is, the ground truth boxes will be matched in descending
order by maximum similarity with any of the respectively remaining
anchor boxes.
The runtime complexity is O(m^2 * n), where `m` is the number of
ground truth boxes and `n` is the number of anchor boxes.
Arguments:
weight_matrix (array): A 2D Numpy array that represents the weight matrix
for the matching process. If `(m,n)` is the shape of the weight matrix,
it must be `m <= n`. The weights can be integers or floating point
numbers. The matching process will maximize, i.e. larger weights are
preferred over smaller weights.
Returns:
A 1D Numpy array of length `weight_matrix.shape[0]` that represents
the matched index along the second axis of `weight_matrix` for each index
along the first axis.
'''
weight_matrix = np.copy(weight_matrix) # We'll modify this array.
num_ground_truth_boxes = weight_matrix.shape[0]
all_gt_indices = list(range(num_ground_truth_boxes)) # Only relevant for fancy-indexing below.
# This 1D array will contain for each ground truth box the index of
# the matched anchor box.
matches = np.zeros(num_ground_truth_boxes, dtype=np.int)
# In each iteration of the loop below, exactly one ground truth box
# will be matched to one anchor box.
for _ in range(num_ground_truth_boxes):
# Find the maximal anchor-ground truth pair in two steps: First, reduce
# over the anchor boxes and then reduce over the ground truth boxes.
anchor_indices = np.argmax(weight_matrix, axis=1) # Reduce along the anchor box axis.
overlaps = weight_matrix[all_gt_indices, anchor_indices]
ground_truth_index = np.argmax(overlaps) # Reduce along the ground truth box axis.
anchor_index = anchor_indices[ground_truth_index]
matches[ground_truth_index] = anchor_index # Set the match.
# Set the row of the matched ground truth box and the column of the matched
# anchor box to all zeros. This ensures that those boxes will not be matched again,
# because they will never be the best matches for any other boxes.
weight_matrix[ground_truth_index] = 0
weight_matrix[:,anchor_index] = 0
return matches
def match_multi(weight_matrix, threshold):
    '''
    Matches every element along the second axis of `weight_matrix` to its best
    match along the first axis, keeping only those pairs whose weight is greater
    than or equal to `threshold`.

    If the weight matrix contains elements that should be ignored, the row or column
    representing the respective element should be set to a value below `threshold`.

    Arguments:
        weight_matrix (array): A 2D Numpy array that represents the weight matrix
            for the matching process. If `(m,n)` is the shape of the weight matrix,
            it must be `m <= n`. The weights can be integers or floating point
            numbers. The matching process will maximize, i.e. larger weights are
            preferred over smaller weights.
        threshold (float): A float that represents the threshold (i.e. lower bound)
            that must be met by a pair of elements to produce a match.

    Returns:
        Two 1D Numpy arrays of equal length that represent the matched indices. The first
        array contains the indices along the first axis of `weight_matrix`, the second array
        contains the indices along the second axis.
    '''
    # One column index per element of the second axis, used to pick a single
    # entry out of every column below.
    column_ids = np.arange(weight_matrix.shape[1])

    # Row holding the largest weight in each column, and that weight itself.
    # Both arrays have shape `(weight_matrix.shape[1],)`.
    best_rows = weight_matrix.argmax(axis=0)
    best_weights = weight_matrix[best_rows, column_ids]

    # Keep only the columns whose best weight reaches the threshold.
    kept_columns = np.flatnonzero(best_weights >= threshold)

    return best_rows[kept_columns], kept_columns