-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add batch method and try to use k means
- Loading branch information
Showing
4 changed files
with
184 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on 2018-04-10 20:11:11 | ||
@Author: ZHAO Lingfeng | ||
@Version : 0.0.1 | ||
""" | ||
import glob | ||
import os | ||
|
||
import cv2 | ||
from stich import Matcher, Sticher, Method | ||
|
||
|
||
def main(): | ||
import time | ||
# main() | ||
os.chdir(os.path.dirname(__file__)) | ||
|
||
number = 9 | ||
file1 = "../resource/{}-right*.jpg".format(number) | ||
file2 = "../resource/{}-left.jpg".format(number) | ||
|
||
start_time = time.time() | ||
for method in (Method.SIFT, Method.ORB): | ||
|
||
for f in glob.glob(file1): | ||
print(f) | ||
name = f.replace('right', method.name) | ||
# print(file2, name) | ||
|
||
img2 = cv2.imread(file2) | ||
img1 = cv2.imread(f) | ||
matcher = Matcher(img1, img2, method=method) | ||
matcher.match(max_match_lenth=20, show_match=False,) | ||
sticher = Sticher(img1, img2, matcher) | ||
sticher.stich(show_result=False) | ||
cv2.imwrite(name, sticher.image) | ||
print("Time: ", time.time() - start_time) | ||
# print("M: ", sticher.M) | ||
print('\a') | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on 2018-04-02 22:12:12 | ||
@Author: ZHAO Lingfeng | ||
@Version : 0.0.1 | ||
""" | ||
from typing import Tuple | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
|
||
def difference(array: np.ndarray): | ||
x = [] | ||
for i in range(len(array) - 1): | ||
x.append(array[i + 1] - array[i]) | ||
|
||
return np.array(x) | ||
|
||
|
||
def find_peek(array: np.ndarray): | ||
peek = difference(difference(array)) | ||
# print(peek) | ||
peek_pos = np.argmax(peek) + 2 | ||
return peek_pos | ||
|
||
|
||
def k_means(points: np.ndarray): | ||
"""返回一个数组经kmeans分类后的k值以及标签,k值由计算拐点给出 | ||
Args: | ||
points (np.ndarray): 需分类数据 | ||
Returns: | ||
Tuple[int, np.ndarry]: k值以及标签数组 | ||
""" | ||
|
||
# Define criteria = ( type, max_iter = 10 , epsilon = 1.0 ) | ||
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0) | ||
|
||
# Set flags (Just to avoid line break in the code) | ||
flags = cv2.KMEANS_RANDOM_CENTERS | ||
length = [] | ||
max_k = min(10, points.shape[0]) | ||
for k in range(2, max_k + 1): | ||
avg = 0 | ||
for i in range(5): | ||
compactness, _, _ = cv2.kmeans( | ||
points, k, None, criteria, 10, flags) | ||
avg += compactness | ||
avg /= 5 | ||
length.append(avg) | ||
|
||
peek_pos = find_peek(length) | ||
k = peek_pos + 2 | ||
# print(k) | ||
return k, cv2.kmeans(points, k, None, criteria, 10, flags)[1] # labels | ||
|
||
|
||
def get_group_center(points1: np.ndarray, points2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | ||
"""输入两个相对应的点对数组,返回经kmeans优化后的两个数组 | ||
Args: | ||
points1 (np.ndarray): 数组一 | ||
points2 (np.ndarray): 数组二 | ||
Returns: | ||
Tuple[np.ndarray, np.ndarray]: 两数组 | ||
""" | ||
|
||
k, labels = k_means(points1) | ||
labels = labels.flatten() | ||
selected_centers1 = [] | ||
selected_centers2 = [] | ||
for i in range(k): | ||
center1 = np.mean(points1[labels == i], axis=0) | ||
center2 = np.mean(points2[labels == i], axis=0) | ||
# center1 = points1[labels == i][0] | ||
# center2 = points2[labels == i][0] | ||
|
||
selected_centers1.append(center1) | ||
selected_centers2.append(center2) | ||
|
||
selected_centers1, selected_centers2 = np.array( | ||
selected_centers1), np.array(selected_centers2) | ||
|
||
# return selected_centers1, selected_centers2 | ||
# return np.append(selected_centers1, points1, axis=0), np.append(selected_centers2, points2, axis=0) | ||
return points1, points2 | ||
|
||
|
||
def main(): | ||
x = np.array([[1, 1], [1, 2], [2, 2], [3, 3]], dtype=np.float32) | ||
y = np.array([[1, 1], [1, 2], [2, 2], [3, 3]], dtype=np.float32) | ||
print(get_group_center(x, y)) | ||
pass | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters