diff --git a/py/batch_stich.py b/py/batch_stich.py
index f9116b6..83cf0fc 100644
--- a/py/batch_stich.py
+++ b/py/batch_stich.py
@@ -8,7 +8,7 @@
 import os
 
 import cv2
-from stich import Matcher, Sticher, Method
+from stich import Sticher, Method
 
 
 def main():
@@ -16,7 +16,7 @@ def main():
     # main()
 
     os.chdir(os.path.dirname(__file__))
-    number = 9
+    number = 17
 
     file1 = "../resource/{}-right*.jpg".format(number)
     file2 = "../resource/{}-left.jpg".format(number)
diff --git a/py/stich.py b/py/stich.py
index ae63c22..e4dc410 100644
--- a/py/stich.py
+++ b/py/stich.py
@@ -104,15 +104,21 @@ def match(self, max_match_lenth=20, threshold=0.04, show_match=False):
             self._descriptors1, self._descriptors2), key=lambda x: x.distance)
         match_len = min(len(self.match_points), max_match_lenth)
-        if self.method == Method.ORB:
-            threshold = 20
-        min_distance = max(2 * self.match_points[0].distance, threshold)
+
+        # if self.method == Method.ORB:
+        #     threshold = 20
+        # elif self.method == Method.SIFT:
+        #     threshold = 20
+        # max_distance = max(2 * self.match_points[0].distance, threshold)
+
+        # floor of 2, in case the best match distance is 0
+        max_distance = max(2 * self.match_points[0].distance, 2)
 
         for i in range(match_len):
-            if self.match_points[i].distance > min_distance:
+            if self.match_points[i].distance > max_distance:
                 match_len = i
                 break
-        print('min distance: ', min_distance)
+        print('max distance: ', self.match_points[match_len - 1].distance)
         print('match_len: ', match_len)
         assert(match_len >= 4)
 
         self.match_points = self.match_points[:match_len]
@@ -167,13 +173,13 @@ def __init__(self, image1: np.ndarray, image2: np.ndarray, method: Enum=Method.S
 
         self.image = None
 
-    def stich(self, show_result=True, show_match_point=True):
+    def stich(self, show_result=True, show_match_point=True, use_partial=False):
         """Stitch the two images together.
 
         show_result (bool, optional): Defaults to True. Whether to show the stitched image.
        show_match_point (bool, optional): Defaults to True. Whether to show the matched points.
         """
-        self.matcher.match(max_match_lenth=20, show_match=show_match_point)
+        self.matcher.match(max_match_lenth=40, show_match=show_match_point)
 
         if self.use_kmeans:
             self.image_points1, self.image_points2 = k_means.get_group_center(
@@ -191,6 +197,8 @@
         height = int(max(bottom, self.image2.shape[0]) - min(top, 0))
         print(width, height)
         width, height = min(width, 10000), min(height, 10000)
+        if use_partial:
+            self.partial_transform()
 
         # translation matrix
         adjustM = np.array(
@@ -219,6 +227,141 @@
 
         if show_result:
             show_image(self.image)
 
+    def partial_transform(self):
+        def distance(p1, p2):
+            return np.sqrt(
+                (p1[0] - p2[0]) * (p1[0] - p2[0]) + (p1[1] - p2[1]) * (p1[1] - p2[1]))
+
+        width = self.image1.shape[0]
+        height = self.image1.shape[1]
+        offset_x = np.min(self.image_points1[:, 0])
+        offset_y = np.min(self.image_points1[:, 1])
+
+        # width = np.max(self.image_points1[:, 0]) - offset_x
+        # height = np.max(self.image_points1[:, 1]) - offset_y
+        x_mid = int((np.max(self.image_points1[:, 0]) + offset_x) / 2)
+        y_mid = int((np.max(self.image_points1[:, 1]) + offset_y) / 2)
+
+        center = [0, 0]
+        up = x_mid
+        down = width - x_mid
+        left = y_mid
+        right = height - y_mid
+
+        ne, se, sw, nw = [], [], [], []
+        # first draft with (row, col) ordering, superseded below
+        # transform_acer = [[center, [up, 0], [up, right]],
+        #                   [center, [down, 0], [0, right]],
+        #                   [center, [down, left], [0, left]],
+        #                   [[up, 0], [up, left], [up, left]]]
+        # per quadrant: the three corners that stay fixed during the local warp
+        transform_acer = [[center, [0, up], [right, up]],
+                          [center, [0, down], [right, 0]],
+                          [center, [left, down], [left, 0]],
+                          [[0, up], [left, up], [left, 0]]]
+        # transform_acer = [[center, up, right],
+        #                   [center, down, right],
+        #                   [center, down, left],
+        #                   [center, up, left]]
+
+        # sort the matched points into the four quadrants
+        for index in range(self.image_points1.shape[0]):
+            point = self.image_points1[index]
+            if point[0] > y_mid:
+                if point[1] > x_mid:
+                    se.append(index)
+                else:
+                    ne.append(index)
+            else:
+                if point[1] > x_mid:
+                    sw.append(index)
+                else:
+                    nw.append(index)
+
+        # find the quadrant with the fewest points, excluding empty ones
+        minmum = np.argmin(
+            list(map(lambda x: len(x) if len(x) > 0 else 65536, [ne, se, sw, nw])))
+        # corrected below only when it is sparse enough
+        min_part = (ne, se, sw, nw)[minmum]
+
+        # debug:
+        print("minimum part: ", minmum, "point len: ", len(
+            min_part), "|", list(map(len, (ne, se, sw, nw))))
+        for index in min_part:
+            point = self.image_points1[index]
+            cv2.circle(self.image1, tuple(
+                map(int, point)), 20, (0, 255, 255), 5)
+
+        # cv2.circle(self.image1, tuple(map(int, (y_mid, x_mid))),
+        #            25, (255, 100, 60), 7)
+
+        # end debug
+
+        if len(min_part) < len(self.image_points1) / 8:
+            for index in min_part:
+                point = self.image_points1[index].tolist()
+                print("Point: ", point)
+                # maybe try other values for this threshold?
+                if distance(self.get_transformed_position(tuple(point)),
+                            self.image_points2[index]) > 10:
+                    def relevtive_point(p):
+                        # absolute image coordinates -> coordinates inside the quadrant
+                        return (p[0] - y_mid if p[0] > y_mid else p[0],
+                                p[1] - x_mid if p[1] > x_mid else p[1])
+                    cv2.circle(self.image1, tuple(map(int, point)),
+                               40, (255, 0, 0), 10)
+                    src_point = transform_acer[minmum].copy()
+                    src_point.append(relevtive_point(point))
+                    # where the matched point of image2 lands in image1's frame
+                    other_point = self.get_transformed_position(
+                        tuple(self.image_points2[index]), M=np.linalg.inv(self.M))
+                    dest_point = transform_acer[minmum].copy()
+                    dest_point.append(relevtive_point(other_point))
+
+                    def a(x): return np.array(x, dtype=np.float32)
+                    print(src_point, dest_point)
+                    partial_M = cv2.getPerspectiveTransform(
+                        a(src_point), a(dest_point))
+
+                    # row and column borders of the quadrant to re-warp
+                    if minmum == 1 or minmum == 2:
+                        boder_0, boder_1 = x_mid, width
+                    else:
+                        boder_0, boder_1 = 0, x_mid
+                    if minmum == 2 or minmum == 3:
+                        boder_2, boder_3 = 0, y_mid
+                    else:
+                        boder_2, boder_3 = y_mid, height
+
+                    print("Changed:",
+                          "\nM: ", partial_M,
+                          "\npart: ", minmum,
+                          "\ndistance: ", distance(self.get_transformed_position(tuple(point)),
+                                                   self.image_points2[index])
+                          )
+                    part = self.image1[boder_0:boder_1, boder_2:boder_3]
+
+                    # debug:
+                    print(boder_0, boder_1, boder_2, boder_3)
+                    for anchor in transform_acer[minmum]:
+                        print(anchor)
+                        cv2.circle(part, tuple(
+                            map(int, anchor)), 40, (220, 200, 200), 10)
+                    for anchor in src_point:
+                        print(anchor)
+                        cv2.circle(part, tuple(
+                            map(int, anchor)), 22, (226, 43, 138), 8)
+                    # for anchor in dest_point:
+                    #     print(anchor)
+                    #     cv2.circle(part, tuple(
+                    #         map(int, anchor)), 25, (20, 97, 199), 8)
+                    # cv2.circle(part, tuple(map(int, relevtive_point(point))),
+                    #            40, (255, 0, 0), 10)
+                    show_image(part)
+                    # end debug
+
+                    part = cv2.warpPerspective(
+                        part, partial_M, (part.shape[1], part.shape[0]))
+                    cv2.circle(part, tuple(map(int, relevtive_point(other_point))),
+                               40, (20, 97, 199), 6)
+                    show_image(part)
+                    self.image1[boder_0:boder_1, boder_2:boder_3] = part
+                    return
+
     def blend(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
         """Blend the two images.
@@ -391,14 +534,14 @@ def main():
     os.chdir(os.path.dirname(__file__))
     start_time = time.time()
-    img2 = cv2.imread("../resource/15-left.jpg")
-    img1 = cv2.imread("../resource/15-right.jpg")
+    img2 = cv2.imread("../resource/13-left.jpg")
+    img1 = cv2.imread("../resource/13-right.jpg")
 
     # matcher = Matcher(img1, img2, Method.ORB)
     # matcher.match(max_match_lenth=20, show_match=True,)
     sticher = Sticher(img1, img2, Method.ORB, False)
-    sticher.stich()
+    sticher.stich(use_partial=True)
 
-    # cv2.imwrite('../resource/15-orb.jpg', sticher.image)
+    cv2.imwrite('../resource/13-partial.jpg', sticher.image)
 
     print("Time: ", time.time() - start_time)
     print("M: ", sticher.M)
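
For reference, the rewritten distance gate in match() keeps a match only while its distance stays within twice the best match's distance, floored at 2 so that a perfect best match (distance 0) does not reject everything. A minimal standalone sketch of the same rule, assuming a non-empty list of cv2.DMatch objects already sorted ascending by distance (as the BFMatcher output is here); filter_matches is a hypothetical helper, not part of the diff:

    def filter_matches(matches, max_match_lenth=20):
        # matches: non-empty, sorted ascending by .distance
        match_len = min(len(matches), max_match_lenth)
        # floor of 2 guards against a best distance of 0
        max_distance = max(2 * matches[0].distance, 2)
        for i in range(match_len):
            if matches[i].distance > max_distance:
                match_len = i
                break
        assert match_len >= 4  # a homography needs at least 4 correspondences
        return matches[:match_len]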
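partial_transform() operates on image1 only: it splits the image into four quadrants around the midpoint of the matched keypoints, picks the quadrant holding the fewest matches, and, when that quadrant is sparse (fewer than 1/8 of all points) and one of its points lands more than 10 px from its counterpart under the global homography M, re-warps just that quadrant. Three quadrant corners are pinned in place and the mismatched point supplies the fourth correspondence. A sketch of that corner-pinning warp under those assumptions (warp_quadrant is a hypothetical standalone helper, not part of this diff):

    import cv2
    import numpy as np

    def warp_quadrant(part, fixed_corners, src_pt, dst_pt):
        # fixed_corners: three (x, y) corners of `part` that must not move;
        # the fourth pair (src_pt -> dst_pt) supplies the local correction.
        src = np.array(fixed_corners + [src_pt], dtype=np.float32)
        dst = np.array(fixed_corners + [dst_pt], dtype=np.float32)
        M = cv2.getPerspectiveTransform(src, dst)  # exactly 4 point pairs
        return cv2.warpPerspective(part, M, (part.shape[1], part.shape[0]))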
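End to end, the new flag is exercised the same way the updated main() does:

    import cv2
    from stich import Sticher, Method

    img2 = cv2.imread("../resource/13-left.jpg")
    img1 = cv2.imread("../resource/13-right.jpg")

    sticher = Sticher(img1, img2, Method.ORB, False)  # as in main(): ORB, use_kmeans=False
    sticher.stich(use_partial=True)                   # re-warp the worst quadrant locally

    cv2.imwrite("../resource/13-partial.jpg", sticher.image)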