From fb4d88449815402e2f2fdd0692478866eb20a1f0 Mon Sep 17 00:00:00 2001 From: zhaobenx Date: Wed, 11 Apr 2018 18:22:53 +0800 Subject: [PATCH] Add batch method and try to use k means --- .../orb\350\247\243\346\236\220.md" | 4 +- py/batch_stich.py | 44 ++++++++ py/k_means.py | 100 ++++++++++++++++++ py/stich.py | 50 ++++++--- 4 files changed, 184 insertions(+), 14 deletions(-) create mode 100644 py/batch_stich.py create mode 100644 py/k_means.py diff --git "a/doc/orb\350\247\243\346\236\220/orb\350\247\243\346\236\220.md" "b/doc/orb\350\247\243\346\236\220/orb\350\247\243\346\236\220.md" index b963457..bbe85f6 100644 --- "a/doc/orb\350\247\243\346\236\220/orb\350\247\243\346\236\220.md" +++ "b/doc/orb\350\247\243\346\236\220/orb\350\247\243\346\236\220.md" @@ -72,6 +72,7 @@ S_{\theta} = R_{\theta}S $$ 其中的$S$为选取的点对($x_n$与$y_n$是一个点对两个点的灰度值),$R_{\theta}$为旋转不变性(公式(4))求得的旋转角度。 + $$ S =\begin{pmatrix}x_1&x_2&\cdots&x_{2n} \\ y_1&y_2&\cdots&y_{2n}\end{pmatrix} $$ @@ -176,6 +177,7 @@ HarrisResponses(const Mat& img, const std::vector& layerinfo, ``` 代码最后几行即为计算我们公式(1)的具体计算,其中用到了这样的公式,即在矩阵$\boldsymbol{M}=\begin{bmatrix}A&C\\C&B\end{bmatrix}$中: + $$ det\boldsymbol{M} = \lambda_1\lambda_2=AB-C^2 $$ @@ -293,4 +295,4 @@ computeOrbDescriptors( const Mat& imagePyramid, const std::vector& layerIn -(三十二个为一个描述子,图中并未截全一个描述子) \ No newline at end of file +(三十二个为一个描述子,图中并未截全一个描述子)s \ No newline at end of file diff --git a/py/batch_stich.py b/py/batch_stich.py new file mode 100644 index 0000000..20bf3fb --- /dev/null +++ b/py/batch_stich.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +""" +Created on 2018-04-10 20:11:11 +@Author: ZHAO Lingfeng +@Version : 0.0.1 +""" +import glob +import os + +import cv2 +from stich import Matcher, Sticher, Method + + +def main(): + import time + # main() + os.chdir(os.path.dirname(__file__)) + + number = 9 + file1 = "../resource/{}-right*.jpg".format(number) + file2 = "../resource/{}-left.jpg".format(number) + + start_time = time.time() + for method in (Method.SIFT, Method.ORB): + + for f in glob.glob(file1): + print(f) + name = f.replace('right', method.name) + # print(file2, name) + + img2 = cv2.imread(file2) + img1 = cv2.imread(f) + matcher = Matcher(img1, img2, method=method) + matcher.match(max_match_lenth=20, show_match=False,) + sticher = Sticher(img1, img2, matcher) + sticher.stich(show_result=False) + cv2.imwrite(name, sticher.image) + print("Time: ", time.time() - start_time) + # print("M: ", sticher.M) + print('\a') + + +if __name__ == "__main__": + main() diff --git a/py/k_means.py b/py/k_means.py new file mode 100644 index 0000000..0252f6b --- /dev/null +++ b/py/k_means.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +""" +Created on 2018-04-02 22:12:12 +@Author: ZHAO Lingfeng +@Version : 0.0.1 +""" +from typing import Tuple + +import cv2 +import numpy as np + + +def difference(array: np.ndarray): + x = [] + for i in range(len(array) - 1): + x.append(array[i + 1] - array[i]) + + return np.array(x) + + +def find_peek(array: np.ndarray): + peek = difference(difference(array)) + # print(peek) + peek_pos = np.argmax(peek) + 2 + return peek_pos + + +def k_means(points: np.ndarray): + """返回一个数组经kmeans分类后的k值以及标签,k值由计算拐点给出 + + Args: + points (np.ndarray): 需分类数据 + + Returns: + Tuple[int, np.ndarry]: k值以及标签数组 + """ + + # Define criteria = ( type, max_iter = 10 , epsilon = 1.0 ) + criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0) + + # Set flags (Just to avoid line break in the code) + flags = cv2.KMEANS_RANDOM_CENTERS + length = [] + max_k = min(10, points.shape[0]) + for k in range(2, max_k + 1): + avg = 0 + for i in range(5): + compactness, _, _ = cv2.kmeans( + points, k, None, criteria, 10, flags) + avg += compactness + avg /= 5 + length.append(avg) + + peek_pos = find_peek(length) + k = peek_pos + 2 + # print(k) + return k, cv2.kmeans(points, k, None, criteria, 10, flags)[1] # labels + + +def get_group_center(points1: np.ndarray, points2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """输入两个相对应的点对数组,返回经kmeans优化后的两个数组 + + Args: + points1 (np.ndarray): 数组一 + points2 (np.ndarray): 数组二 + + Returns: + Tuple[np.ndarray, np.ndarray]: 两数组 + """ + + k, labels = k_means(points1) + labels = labels.flatten() + selected_centers1 = [] + selected_centers2 = [] + for i in range(k): + center1 = np.mean(points1[labels == i], axis=0) + center2 = np.mean(points2[labels == i], axis=0) + # center1 = points1[labels == i][0] + # center2 = points2[labels == i][0] + + selected_centers1.append(center1) + selected_centers2.append(center2) + + selected_centers1, selected_centers2 = np.array( + selected_centers1), np.array(selected_centers2) + + # return selected_centers1, selected_centers2 + # return np.append(selected_centers1, points1, axis=0), np.append(selected_centers2, points2, axis=0) + return points1, points2 + + +def main(): + x = np.array([[1, 1], [1, 2], [2, 2], [3, 3]], dtype=np.float32) + y = np.array([[1, 1], [1, 2], [2, 2], [3, 3]], dtype=np.float32) + print(get_group_center(x, y)) + pass + + +if __name__ == "__main__": + main() diff --git a/py/stich.py b/py/stich.py index ca52284..2465606 100644 --- a/py/stich.py +++ b/py/stich.py @@ -14,6 +14,8 @@ import cv2 import numpy as np +import k_means + def show_image(image: np.ndarray) -> None: from PIL import Image @@ -102,6 +104,8 @@ def match(self, max_match_lenth=20, threshold=0.04, show_match=False): self._descriptors1, self._descriptors2), key=lambda x: x.distance) match_len = min(len(self.match_points), max_match_lenth) + if self.method == Method.ORB: + threshold = 20 min_distance = max(2 * self.match_points[0].distance, threshold) for i in range(match_len): @@ -130,9 +134,20 @@ def match(self, max_match_lenth=20, threshold=0.04, show_match=False): # print(image_points1) +def get_weighted_points(image_points: np.ndarray): + + # print(k_means.k_means(image_points)) + # exit(0) + + average = np.average(image_points, axis=0) + + max_index = np.argmax(np.linalg.norm((image_points - average), axis=1)) + return np.append(image_points, np.array([image_points[max_index]]), axis=0) + + class Sticher: - def __init__(self, image1: np.ndarray, image2: np.ndarray, matcher: Matcher): + def __init__(self, image1: np.ndarray, image2: np.ndarray, matcher: Matcher, use_kmeans=False): """输入图像和匹配,对图像进行拼接 目前采用简单矩阵匹配和平均值拼合 @@ -140,12 +155,17 @@ def __init__(self, image1: np.ndarray, image2: np.ndarray, matcher: Matcher): image1 (np.ndarray): 图像一 image2 (np.ndarray): 图像二 matcher (Matcher): 匹配结果 + use_kmeans (bool): 是否使用kmeans 优化点选择 """ self.image1 = image1 self.image2 = image2 - self.image_points1 = matcher.image_points1 - self.image_points2 = matcher.image_points2 + if use_kmeans: + self.image_points1, self.image_points2 = k_means.get_group_center( + matcher.image_points1, matcher.image_points2) + else: + self.image_points1, self.image_points2 = ( + matcher.image_points1, matcher.image_points2) self.M = np.eye(3) @@ -165,7 +185,8 @@ def stich(self, show_result=True, show_match_point=True): # print(self.get_transformed_size()) width = int(max(right, self.image2.shape[1]) - min(left, 0)) height = int(max(bottom, self.image2.shape[0]) - min(top, 0)) - # print(width, height) + print(width, height) + width, height = min(width, 10000), min(height, 10000) # 移动矩阵 adjustM = np.array( @@ -186,11 +207,11 @@ def stich(self, show_result=True, show_match_point=True): for point in self.image_points1: point = self.get_transformed_position(tuple(point)) point = tuple(map(int, point)) - cv2.circle(self.image, point, 10, (20, 20, 255)) + cv2.circle(self.image, point, 10, (20, 20, 255), 5) for point in self.image_points2: point = self.get_transformed_position(tuple(point), M=adjustM) point = tuple(map(int, point)) - cv2.circle(self.image, point, 8, (20, 200, 20)) + cv2.circle(self.image, point, 8, (20, 200, 20), 5) if show_result: show_image(self.image) @@ -209,7 +230,7 @@ def blend(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray: # result[0:image2.shape[0], 0:self.image2.shape[1]] = self.image2 # result = np.maximum(transformed, result) - # result = cv2.addWeighted(transformed, 0.5, result, 0.5, 1) + # result = cv2.addWeighted(image1, 0.5, image2, 0.5, 1) result = self.average(image1, image2) return result @@ -235,7 +256,8 @@ def average(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray: ) # 重叠处用平均值 result[overlap] = np.average( - np.array([image1[overlap], image2[overlap]]), axis=0) .astype(np.uint8) + np.array([image1[overlap], image2[overlap]]), axis=0 + ) .astype(np.uint8) # 非重叠处采选最大值 not_overlap = np.logical_not(overlap) result[not_overlap] = np.maximum( @@ -365,12 +387,14 @@ def main(): os.chdir(os.path.dirname(__file__)) start_time = time.time() - img1 = cv2.imread("../resource/3-left.jpg") - img2 = cv2.imread("../resource/3-right.jpg") + img2 = cv2.imread("../resource/15-left.jpg") + img1 = cv2.imread("../resource/15-right.jpg") matcher = Matcher(img1, img2, Method.ORB) - matcher.match(show_match=True) - sticher = Sticher(img1, img2, matcher) + matcher.match(max_match_lenth=20, show_match=True,) + sticher = Sticher(img1, img2, matcher, False) sticher.stich() - cv2.imwrite('../resource/3-orb.jpg', sticher.image) + + # cv2.imwrite('../resource/15-orb.jpg', sticher.image) + print("Time: ", time.time() - start_time) print("M: ", sticher.M)