Skip to content

Commit

Permalink
Refactor to hide Matcher from public
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaobenx committed Apr 11, 2018
1 parent fb4d884 commit b47fe7c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
4 changes: 1 addition & 3 deletions py/batch_stich.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def main():

img2 = cv2.imread(file2)
img1 = cv2.imread(f)
matcher = Matcher(img1, img2, method=method)
matcher.match(max_match_lenth=20, show_match=False,)
sticher = Sticher(img1, img2, matcher)
sticher = Sticher(img1, img2, method=method)
sticher.stich(show_result=False)
cv2.imwrite(name, sticher.image)
print("Time: ", time.time() - start_time)
Expand Down
30 changes: 17 additions & 13 deletions py/stich.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_weighted_points(image_points: np.ndarray):

class Sticher:

def __init__(self, image1: np.ndarray, image2: np.ndarray, matcher: Matcher, use_kmeans=False):
def __init__(self, image1: np.ndarray, image2: np.ndarray, method: Enum=Method.SURF, use_kmeans=False):
"""输入图像和匹配,对图像进行拼接
目前采用简单矩阵匹配和平均值拼合
Expand All @@ -160,13 +160,9 @@ def __init__(self, image1: np.ndarray, image2: np.ndarray, matcher: Matcher, use

self.image1 = image1
self.image2 = image2
if use_kmeans:
self.image_points1, self.image_points2 = k_means.get_group_center(
matcher.image_points1, matcher.image_points2)
else:
self.image_points1, self.image_points2 = (
matcher.image_points1, matcher.image_points2)

self.method = method
self.use_kmeans = use_kmeans
self.matcher = Matcher(image1, image2, method=method)
self.M = np.eye(3)

self.image = None
Expand All @@ -177,6 +173,14 @@ def stich(self, show_result=True, show_match_point=True):
show_result (bool, optional): Defaults to True. 是否展示拼合图像
show_match_point (bool, optional): Defaults to True. 是否展示拼合点
"""
self.matcher.match(max_match_lenth=20, show_match=show_match_point)

if self.use_kmeans:
self.image_points1, self.image_points2 = k_means.get_group_center(
self.matcher.image_points1, self.matcher.image_points2)
else:
self.image_points1, self.image_points2 = (
self.matcher.image_points1, self.matcher.image_points2)

self.M, _ = cv2.findHomography(
self.image_points1, self.image_points2, cv2.RANSAC)
Expand Down Expand Up @@ -389,12 +393,12 @@ def main():
start_time = time.time()
img2 = cv2.imread("../resource/15-left.jpg")
img1 = cv2.imread("../resource/15-right.jpg")
matcher = Matcher(img1, img2, Method.ORB)
matcher.match(max_match_lenth=20, show_match=True,)
sticher = Sticher(img1, img2, matcher, False)
# matcher = Matcher(img1, img2, Method.ORB)
# matcher.match(max_match_lenth=20, show_match=True,)
sticher = Sticher(img1, img2, Method.ORB, False)
sticher.stich()

# cv2.imwrite('../resource/15-orb.jpg', sticher.image)

print("Time: ", time.time() - start_time)
print("M: ", sticher.M)

0 comments on commit b47fe7c

Please sign in to comment.