-
Notifications
You must be signed in to change notification settings - Fork 29
/
knn_matting.py
59 lines (49 loc) · 1.89 KB
/
knn_matting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import warnings

import numpy as np
import scipy.sparse
import scipy.sparse.linalg
import sklearn.neighbors
# Default number of nearest neighbours per pixel used by knn_matte.
nn = 10


def knn_matte(img, trimap, mylambda=100, n_neighbors=nn, n_jobs=4):
    """Estimate an alpha matte with KNN matting.

    Every pixel is described by a feature vector of its colour channels plus
    its (row, col) position scaled by the image diagonal. Each pixel is linked
    to its ``n_neighbors`` nearest neighbours in that feature space with weight
    ``1 - ||f_i - f_j|| / (c + 2)``. From this affinity matrix ``A`` the
    Laplacian ``L = diag(row_sums(A)) - A`` is built, and the system
    ``2 * (L + mylambda * D) alpha = 2 * mylambda * fg`` is solved, where ``D``
    is diagonal and marks the constrained (definite fg/bg) pixels.

    Parameters
    ----------
    img : ndarray, shape (m, n, c) or (m, n)
        Input image, expected on a 0-255 scale (it is divided by 255).
    trimap : ndarray, shape (m, n, k) or (m, n)
        Trimap on a 0-255 scale. After division by 255, values > 0.99 are
        foreground, values < 0.01 are background and the rest are unknown.
        For a multi-channel trimap only channel 0 is used.
    mylambda : float
        Weight of the foreground/background constraints.
    n_neighbors : int
        Neighbours per pixel. It is clamped to the number of pixels.
    n_jobs : int
        Parallel jobs for sklearn's ``NearestNeighbors``.

    Returns
    -------
    ndarray, shape (m, n)
        Alpha values clipped to [0, 1].
    """
    img = np.asarray(img, dtype=float) / 255.0
    trimap = np.asarray(trimap, dtype=float) / 255.0
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if trimap.ndim == 3:
        trimap = trimap[:, :, 0]
    m, n, c = img.shape
    num_px = m * n

    foreground = (trimap > 0.99).astype(float)
    background = (trimap < 0.01).astype(float)
    constrained = foreground + background

    print('Finding nearest neighbors')
    rows, cols = np.unravel_index(np.arange(num_px), (m, n))
    # Pixel coordinates divided by the image diagonal, so they lie in [0, 1)
    # like the colour channels (which were divided by 255 above).
    coords = np.stack([rows, cols], axis=1) / np.sqrt(m * m + n * n)
    feature_vec = np.hstack([img.reshape(num_px, c), coords])
    k = min(n_neighbors, num_px)
    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k, n_jobs=n_jobs).fit(feature_vec)
    # The query set is passed explicitly, so each pixel appears among its
    # own neighbours.
    knns = nbrs.kneighbors(feature_vec, return_distance=False)

    print('Computing sparse A')
    row_inds = np.repeat(np.arange(num_px), k)
    col_inds = knns.reshape(-1)
    # c + 2 is the feature dimension (colour channels + two coordinates).
    vals = 1 - np.linalg.norm(feature_vec[row_inds] - feature_vec[col_inds], axis=1) / (c + 2)
    A = scipy.sparse.coo_matrix((vals, (row_inds, col_inds)), shape=(num_px, num_px))
    D_script = scipy.sparse.diags(np.asarray(A.sum(axis=1)).ravel())
    L = D_script - A
    D = scipy.sparse.diags(np.ravel(constrained))
    rhs = 2 * mylambda * np.ravel(foreground)
    # spsolve warns when given a non-CSC/CSR matrix. Converting here keeps that
    # efficiency warning from triggering the fallback below.
    H = (2 * (L + mylambda * D)).tocsc()

    print('Solving linear system for alpha')
    # Escalate warnings to errors only while spsolve runs (for example SciPy's
    # singular-matrix warning), so the code can fall back to least squares.
    # catch_warnings restores the caller's warning filters afterwards.
    try:
        with warnings.catch_warnings():
            warnings.simplefilter('error')
            x = scipy.sparse.linalg.spsolve(H, rhs)
    except Warning:
        x = scipy.sparse.linalg.lsqr(H, rhs)[0]
    return np.clip(x, 0, 1).reshape(m, n)
def _read_image_255(path):
    """Read an image file as an array on the 0-255 scale knn_matte expects."""
    import matplotlib.pyplot as plt

    image = plt.imread(path)
    # plt.imread returns floats in [0, 1] for PNG files (integer arrays for
    # formats read through Pillow), so rescale float data to 0-255.
    if np.issubdtype(image.dtype, np.floating):
        image = np.round(image * 255.0)
    return image


def main():
    """Matte the donkey example image, then save and display the alpha.

    Reads ``donkey.png`` and ``donkeyTrimap.png`` from the working directory
    and keeps the first three channels of each. Writes the matte to
    ``donkeyAlpha.png`` and shows it in a matplotlib window.
    """
    # Imported here instead of relying on the ``__main__`` guard, so main()
    # also works when called from another module.
    import matplotlib.pyplot as plt

    # scipy.misc.imread/imsave were removed in SciPy 1.2, so matplotlib's
    # image I/O is used instead.
    img = _read_image_255('donkey.png')[:, :, :3]
    trimap = _read_image_255('donkeyTrimap.png')[:, :, :3]
    alpha = knn_matte(img, trimap)
    # A fixed [0, 1] range maps alpha linearly to grey levels instead of
    # stretching whatever min/max the matte happens to have.
    plt.imsave('donkeyAlpha.png', alpha, cmap='gray', vmin=0.0, vmax=1.0)
    plt.title('Alpha Matte')
    plt.imshow(alpha, cmap='gray')
    plt.show()
if __name__ == '__main__':
    # Only when run as a script: bind `scipy.misc` and `plt` as module-level
    # names, then run the demo.
    import scipy.misc
    from matplotlib import pyplot as plt

    main()