Skip to content

Commit

Permalink
Add batch method and try to use k means
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaobenx committed Apr 11, 2018
1 parent e62aaf2 commit fb4d884
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 14 deletions.
4 changes: 3 additions & 1 deletion doc/orb解析/orb解析.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ S_{\theta} = R_{\theta}S
$$

其中的$S$为选取的点对($x_n$与$y_n$是一个点对两个点的灰度值),$R_{\theta}$为旋转不变性(公式(4))求得的旋转角度。

$$
S =\begin{pmatrix}x_1&x_2&\cdots&x_{2n} \\ y_1&y_2&\cdots&y_{2n}\end{pmatrix}
$$
Expand Down Expand Up @@ -176,6 +177,7 @@ HarrisResponses(const Mat& img, const std::vector<Rect>& layerinfo,
```
代码最后几行即为计算我们公式(1)的具体计算,其中用到了这样的公式,即在矩阵$\boldsymbol{M}=\begin{bmatrix}A&C\\C&B\end{bmatrix}$中:
$$
det\boldsymbol{M} = \lambda_1\lambda_2=AB-C^2
$$
Expand Down Expand Up @@ -293,4 +295,4 @@ computeOrbDescriptors( const Mat& imagePyramid, const std::vector<Rect>& layerIn
(三十二个为一个描述子,图中并未截全一个描述子)
(三十二个为一个描述子,图中并未截全一个描述子)s
44 changes: 44 additions & 0 deletions py/batch_stich.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""
Created on 2018-04-10 20:11:11
@Author: ZHAO Lingfeng
@Version : 0.0.1
"""
import glob
import os

import cv2
from stich import Matcher, Sticher, Method


def main():
import time
# main()
os.chdir(os.path.dirname(__file__))

number = 9
file1 = "../resource/{}-right*.jpg".format(number)
file2 = "../resource/{}-left.jpg".format(number)

start_time = time.time()
for method in (Method.SIFT, Method.ORB):

for f in glob.glob(file1):
print(f)
name = f.replace('right', method.name)
# print(file2, name)

img2 = cv2.imread(file2)
img1 = cv2.imread(f)
matcher = Matcher(img1, img2, method=method)
matcher.match(max_match_lenth=20, show_match=False,)
sticher = Sticher(img1, img2, matcher)
sticher.stich(show_result=False)
cv2.imwrite(name, sticher.image)
print("Time: ", time.time() - start_time)
# print("M: ", sticher.M)
print('\a')


if __name__ == "__main__":
main()
100 changes: 100 additions & 0 deletions py/k_means.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
"""
Created on 2018-04-02 22:12:12
@Author: ZHAO Lingfeng
@Version : 0.0.1
"""
from typing import Tuple

import cv2
import numpy as np


def difference(array: np.ndarray):
x = []
for i in range(len(array) - 1):
x.append(array[i + 1] - array[i])

return np.array(x)


def find_peek(array: np.ndarray):
peek = difference(difference(array))
# print(peek)
peek_pos = np.argmax(peek) + 2
return peek_pos


def k_means(points: np.ndarray):
"""返回一个数组经kmeans分类后的k值以及标签,k值由计算拐点给出
Args:
points (np.ndarray): 需分类数据
Returns:
Tuple[int, np.ndarry]: k值以及标签数组
"""

# Define criteria = ( type, max_iter = 10 , epsilon = 1.0 )
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)

# Set flags (Just to avoid line break in the code)
flags = cv2.KMEANS_RANDOM_CENTERS
length = []
max_k = min(10, points.shape[0])
for k in range(2, max_k + 1):
avg = 0
for i in range(5):
compactness, _, _ = cv2.kmeans(
points, k, None, criteria, 10, flags)
avg += compactness
avg /= 5
length.append(avg)

peek_pos = find_peek(length)
k = peek_pos + 2
# print(k)
return k, cv2.kmeans(points, k, None, criteria, 10, flags)[1] # labels


def get_group_center(points1: np.ndarray, points2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""输入两个相对应的点对数组,返回经kmeans优化后的两个数组
Args:
points1 (np.ndarray): 数组一
points2 (np.ndarray): 数组二
Returns:
Tuple[np.ndarray, np.ndarray]: 两数组
"""

k, labels = k_means(points1)
labels = labels.flatten()
selected_centers1 = []
selected_centers2 = []
for i in range(k):
center1 = np.mean(points1[labels == i], axis=0)
center2 = np.mean(points2[labels == i], axis=0)
# center1 = points1[labels == i][0]
# center2 = points2[labels == i][0]

selected_centers1.append(center1)
selected_centers2.append(center2)

selected_centers1, selected_centers2 = np.array(
selected_centers1), np.array(selected_centers2)

# return selected_centers1, selected_centers2
# return np.append(selected_centers1, points1, axis=0), np.append(selected_centers2, points2, axis=0)
return points1, points2


def main():
x = np.array([[1, 1], [1, 2], [2, 2], [3, 3]], dtype=np.float32)
y = np.array([[1, 1], [1, 2], [2, 2], [3, 3]], dtype=np.float32)
print(get_group_center(x, y))
pass


if __name__ == "__main__":
main()
50 changes: 37 additions & 13 deletions py/stich.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import cv2
import numpy as np

import k_means


def show_image(image: np.ndarray) -> None:
from PIL import Image
Expand Down Expand Up @@ -102,6 +104,8 @@ def match(self, max_match_lenth=20, threshold=0.04, show_match=False):
self._descriptors1, self._descriptors2), key=lambda x: x.distance)

match_len = min(len(self.match_points), max_match_lenth)
if self.method == Method.ORB:
threshold = 20
min_distance = max(2 * self.match_points[0].distance, threshold)

for i in range(match_len):
Expand Down Expand Up @@ -130,22 +134,38 @@ def match(self, max_match_lenth=20, threshold=0.04, show_match=False):
# print(image_points1)


def get_weighted_points(image_points: np.ndarray):

# print(k_means.k_means(image_points))
# exit(0)

average = np.average(image_points, axis=0)

max_index = np.argmax(np.linalg.norm((image_points - average), axis=1))
return np.append(image_points, np.array([image_points[max_index]]), axis=0)


class Sticher:

def __init__(self, image1: np.ndarray, image2: np.ndarray, matcher: Matcher):
def __init__(self, image1: np.ndarray, image2: np.ndarray, matcher: Matcher, use_kmeans=False):
"""输入图像和匹配,对图像进行拼接
目前采用简单矩阵匹配和平均值拼合
Args:
image1 (np.ndarray): 图像一
image2 (np.ndarray): 图像二
matcher (Matcher): 匹配结果
use_kmeans (bool): 是否使用kmeans 优化点选择
"""

self.image1 = image1
self.image2 = image2
self.image_points1 = matcher.image_points1
self.image_points2 = matcher.image_points2
if use_kmeans:
self.image_points1, self.image_points2 = k_means.get_group_center(
matcher.image_points1, matcher.image_points2)
else:
self.image_points1, self.image_points2 = (
matcher.image_points1, matcher.image_points2)

self.M = np.eye(3)

Expand All @@ -165,7 +185,8 @@ def stich(self, show_result=True, show_match_point=True):
# print(self.get_transformed_size())
width = int(max(right, self.image2.shape[1]) - min(left, 0))
height = int(max(bottom, self.image2.shape[0]) - min(top, 0))
# print(width, height)
print(width, height)
width, height = min(width, 10000), min(height, 10000)

# 移动矩阵
adjustM = np.array(
Expand All @@ -186,11 +207,11 @@ def stich(self, show_result=True, show_match_point=True):
for point in self.image_points1:
point = self.get_transformed_position(tuple(point))
point = tuple(map(int, point))
cv2.circle(self.image, point, 10, (20, 20, 255))
cv2.circle(self.image, point, 10, (20, 20, 255), 5)
for point in self.image_points2:
point = self.get_transformed_position(tuple(point), M=adjustM)
point = tuple(map(int, point))
cv2.circle(self.image, point, 8, (20, 200, 20))
cv2.circle(self.image, point, 8, (20, 200, 20), 5)
if show_result:
show_image(self.image)

Expand All @@ -209,7 +230,7 @@ def blend(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
# result[0:image2.shape[0], 0:self.image2.shape[1]] = self.image2
# result = np.maximum(transformed, result)

# result = cv2.addWeighted(transformed, 0.5, result, 0.5, 1)
# result = cv2.addWeighted(image1, 0.5, image2, 0.5, 1)
result = self.average(image1, image2)

return result
Expand All @@ -235,7 +256,8 @@ def average(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
)
# 重叠处用平均值
result[overlap] = np.average(
np.array([image1[overlap], image2[overlap]]), axis=0) .astype(np.uint8)
np.array([image1[overlap], image2[overlap]]), axis=0
) .astype(np.uint8)
# 非重叠处采选最大值
not_overlap = np.logical_not(overlap)
result[not_overlap] = np.maximum(
Expand Down Expand Up @@ -365,12 +387,14 @@ def main():
os.chdir(os.path.dirname(__file__))

start_time = time.time()
img1 = cv2.imread("../resource/3-left.jpg")
img2 = cv2.imread("../resource/3-right.jpg")
img2 = cv2.imread("../resource/15-left.jpg")
img1 = cv2.imread("../resource/15-right.jpg")
matcher = Matcher(img1, img2, Method.ORB)
matcher.match(show_match=True)
sticher = Sticher(img1, img2, matcher)
matcher.match(max_match_lenth=20, show_match=True,)
sticher = Sticher(img1, img2, matcher, False)
sticher.stich()
cv2.imwrite('../resource/3-orb.jpg', sticher.image)

# cv2.imwrite('../resource/15-orb.jpg', sticher.image)

print("Time: ", time.time() - start_time)
print("M: ", sticher.M)

0 comments on commit fb4d884

Please sign in to comment.