From 1d514d127dfaee528750301fc6dd5abdffbad1f2 Mon Sep 17 00:00:00 2001 From: MGlauer Date: Mon, 25 Sep 2023 23:15:01 +0200 Subject: [PATCH] Fix box intersection logic --- chebai/models/box_eval.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/chebai/models/box_eval.py b/chebai/models/box_eval.py index 6b6b7dd3..fa094b73 100644 --- a/chebai/models/box_eval.py +++ b/chebai/models/box_eval.py @@ -2,6 +2,7 @@ n = len(boxes) containment_matrix = np.zeros((n, n), dtype=bool) +threshold = 0.99 for i in range(n): for j in range(n): @@ -9,27 +10,19 @@ box1 = boxes[i] box2 = boxes[j] - min_corners_box_1 = np.minimum(box1[0], box1[1]) - max_corners_box_1 = np.maximum(box1[0], box1[1]) + min_corners_box_1 = np.min(box1, axis=-2) + max_corners_box_1 = np.max(box1, axis=-2) - min_corners_box_2 = np.minimum(box2[0], box2[1]) - max_corners_box_2 = np.maximum(box2[0], box2[1]) + vol_box_1 = np.prod(max_corners_box_1 - min_corners_box_1) - dim = len(min_corners_box_1) + min_corners_box_2 = np.min(box2, axis=-2) + max_corners_box_2 = np.max(box2, axis=-2) - membership_per_dim = [] - for d in range(dim): - a = max(min_corners_box_1[d], min_corners_box_2[d]) - b = min(max_corners_box_1[d], max_corners_box_2[d]) - intersection = (a <= b) * (b - a) - size_of_a = min_corners_box_1[d] + max_corners_box_1[d] + a = np.maximum(min_corners_box_1, min_corners_box_2) # right face of intersection + b = np.minimum(max_corners_box_1, max_corners_box_2) # left face of intersection - # if box_1 is not contained in box_2, then is_contained is zero + intersection_per_dim = (a <= b) * (b - a) + vol_intersection = np.prod(intersection_per_dim) + box_1_is_contained_in_box_2 = vol_intersection / vol_box_1 - is_contained = abs(intersection / size_of_a) - membership_per_dim.append(1 if is_contained else 0) - - - count = sum(1 for item in membership_per_dim if item == 1) - box_1_is_contained_in_box_2 = (count == 10) - containment_matrix[i][j] = box_1_is_contained_in_box_2 \ No newline at end of file + containment_matrix[i][j] = box_1_is_contained_in_box_2 > threshold