Skip to content

Commit

Permalink
Add partial stich method,(not perfect!
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaobenx committed Apr 12, 2018
1 parent b47fe7c commit ca567e5
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 13 deletions.
4 changes: 2 additions & 2 deletions py/batch_stich.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import os

import cv2
from stich import Matcher, Sticher, Method
from stich import Sticher, Method


def main():
import time
# main()
os.chdir(os.path.dirname(__file__))

number = 9
number = 17
file1 = "../resource/{}-right*.jpg".format(number)
file2 = "../resource/{}-left.jpg".format(number)

Expand Down
165 changes: 154 additions & 11 deletions py/stich.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,21 @@ def match(self, max_match_lenth=20, threshold=0.04, show_match=False):
self._descriptors1, self._descriptors2), key=lambda x: x.distance)

match_len = min(len(self.match_points), max_match_lenth)
if self.method == Method.ORB:
threshold = 20
min_distance = max(2 * self.match_points[0].distance, threshold)

# if self.method == Method.ORB:
# threshold = 20
# elif self.method == Method.SIFT:
# threshold = 20
# max_distance = max(2 * self.match_points[0].distance, threshold)

# in case distance is 0
max_distance = max(2 * self.match_points[0].distance, 2)

for i in range(match_len):
if self.match_points[i].distance > min_distance:
if self.match_points[i].distance > max_distance:
match_len = i
break
print('min distance: ', min_distance)
print('max distance: ', self.match_points[match_len].distance)
print('match_len: ', match_len)
assert(match_len >= 4)
self.match_points = self.match_points[:match_len]
Expand Down Expand Up @@ -167,13 +173,13 @@ def __init__(self, image1: np.ndarray, image2: np.ndarray, method: Enum=Method.S

self.image = None

def stich(self, show_result=True, show_match_point=True):
def stich(self, show_result=True, show_match_point=True, use_partial=False):
"""对图片进行拼合
show_result (bool, optional): Defaults to True. 是否展示拼合图像
show_match_point (bool, optional): Defaults to True. 是否展示拼合点
"""
self.matcher.match(max_match_lenth=20, show_match=show_match_point)
self.matcher.match(max_match_lenth=40, show_match=show_match_point)

if self.use_kmeans:
self.image_points1, self.image_points2 = k_means.get_group_center(
Expand All @@ -191,6 +197,8 @@ def stich(self, show_result=True, show_match_point=True):
height = int(max(bottom, self.image2.shape[0]) - min(top, 0))
print(width, height)
width, height = min(width, 10000), min(height, 10000)
if use_partial:
self.partial_transform()

# 移动矩阵
adjustM = np.array(
Expand Down Expand Up @@ -219,6 +227,141 @@ def stich(self, show_result=True, show_match_point=True):
if show_result:
show_image(self.image)

def partial_transform(self):
def distance(p1, p2):
return np.sqrt(
(p1[0] - p2[0]) * (p1[0] - p2[0]) + (p1[1] - p2[1]) * (p1[1] - p2[1]))
width = self.image1.shape[0]
height = self.image1.shape[1]
offset_x = np.min(self.image_points1[:, 0])
offset_y = np.min(self.image_points1[:, 1])

# width = np.max(self.image_points1[:, 0]) - offset_x
# height = np.max(self.image_points1[:, 1]) - offset_y
x_mid = int((np.max(self.image_points1[:, 0]) + offset_x) / 2)
y_mid = int((np.max(self.image_points1[:, 1]) + offset_y) / 2)

center = [0, 0]
up = x_mid
down = width - x_mid
left = y_mid
right = height - y_mid

ne, se, sw, nw = [], [], [], []
transform_acer = [[center, [up, 0], [up, right]],
[center, [down, 0], [0, right]],
[center, [down, left], [0, left]],
[[up, 0], [up, left], [up, left]]]
transform_acer = [[center, [0, up], [right, up]],
[center, [0, down], [right, 0]],
[center, [left, down], [left, 0]],
[[0, up], [left, up], [left, up]]]
# transform_acer = [[center, up, right],
# [center, down, right],
# [center, down, left],
# [center, up, left]]

# 对点的位置进行分类
for index in range(self.image_points1.shape[0]):
point = self.image_points1[index]
if point[0] > y_mid:
if point[1] > x_mid:
se.append(index)
else:
ne.append(index)
else:
if point[1] > x_mid:
sw.append(index)
else:
nw.append(index)

# 求点最少处位置,排除零
minmum = np.argmin(
list(map(lambda x: len(x) if len(x) > 0 else 65536, [ne, se, sw, nw])))
# 当足够少时
min_part = (ne, se, sw, nw)[minmum]

# debug:
print("minum part: ", minmum, "point len: ", len(
min_part), "|", list(map(len, (ne, se, sw, nw))))
for index in min_part:
point = self.image_points1[index]
cv2.circle(self.image1, tuple(
map(int, point)), 20, (0, 255, 255), 5)

# cv2.circle(self.image1, tuple(map(int, (y_mid, x_mid))),
# 25, (255, 100, 60), 7)

# end debug

if len(min_part) < len(self.image_points1) / 8:
for index in min_part:
point = self.image_points1[index].tolist()
print("Point: ", point)
# maybe can try other value?
if distance(self.get_transformed_position(tuple(point)),
self.image_points2[index]) > 10:
def relevtive_point(p):
return (p[0] - y_mid if p[0] > y_mid else p[0],
p[1] - x_mid if p[1] > x_mid else p[1])
cv2.circle(self.image1, tuple(map(int, point)),
40, (255, 0, 0), 10)
src_point = transform_acer[minmum].copy()
src_point.append(relevtive_point(point))
other_point = self.get_transformed_position(
tuple(self.image_points2[index]), M=np.linalg.inv(self.M))
dest_point = transform_acer[minmum].copy()
dest_point.append(relevtive_point(other_point))

def a(x): return np.array(x, dtype=np.float32)
print(src_point, dest_point)
partial_M = cv2.getPerspectiveTransform(
a(src_point), a(dest_point))

if minmum == 1 or minmum == 2:
boder_0, boder_1 = x_mid, width
else:
boder_0, boder_1 = 0, x_mid
if minmum == 2 or minmum == 3:
boder_2, boder_3 = 0, y_mid
else:
boder_2, boder_3 = y_mid, height

print("Changed:",
"\nM: ", partial_M,
"\npart: ", minmum,
"\ndistance: ", distance(self.get_transformed_position(tuple(point)),
self.image_points2[index])
)
part = self.image1[boder_0:boder_1, boder_2:boder_3]

#
print(boder_0, boder_1, boder_2, boder_3)
for point in transform_acer[minmum]:
print(point)
cv2.circle(part, tuple(
map(int, point)), 40, (220, 200, 200), 10)
for point in src_point:
print(point)
cv2.circle(part, tuple(
map(int, point)), 22, (226, 43, 138), 8)
# for point in dest_point:
# print(point)
# cv2.circle(part, tuple(
# map(int, point)), 25, (20, 97, 199), 8)
#
# cv2.circle(part, tuple(map(int, relevtive_point(point))),
# 40, (255, 0, 0), 10)
show_image(part)

part = cv2.warpPerspective(
part, partial_M, (part.shape[1], part.shape[0]))
cv2.circle(part, tuple(map(int, relevtive_point(other_point))),
40, (20, 97, 199), 6)
show_image(part)
self.image1[boder_0:boder_1, boder_2:boder_3] = part
return

def blend(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
"""对图像进行融合
Expand Down Expand Up @@ -391,14 +534,14 @@ def main():
os.chdir(os.path.dirname(__file__))

start_time = time.time()
img2 = cv2.imread("../resource/15-left.jpg")
img1 = cv2.imread("../resource/15-right.jpg")
img2 = cv2.imread("../resource/13-left.jpg")
img1 = cv2.imread("../resource/13-right.jpg")
# matcher = Matcher(img1, img2, Method.ORB)
# matcher.match(max_match_lenth=20, show_match=True,)
sticher = Sticher(img1, img2, Method.ORB, False)
sticher.stich()
sticher.stich(use_partial=True)

# cv2.imwrite('../resource/15-orb.jpg', sticher.image)
cv2.imwrite('../resource/13-partial.jpg', sticher.image)

print("Time: ", time.time() - start_time)
print("M: ", sticher.M)

0 comments on commit ca567e5

Please sign in to comment.