omp.py
'''
Batched Orthogonal Matching Pursuit (OMP) in PyTorch.

Based on
https://github.com/Ariel5/omp-parallel-gpu-python
'''
import torch


def innerp(x, y=None, out=None):
    if y is None:
        y = x
    if out is not None:
        out = out[:, None, None]  # Add space for two singleton dimensions.
    return torch.matmul(x[..., None, :], y[..., :, None], out=out)[..., 0, 0]
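
# innerp contracts the last dimension of its arguments, so innerp(y) with y of
# shape (B, N) returns the per-row squared norms with shape (B,). Illustrative
# sanity check (not executed here):
#   y = torch.randn(8, 32)
#   assert torch.allclose(innerp(y), (y * y).sum(-1), atol=1e-6)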


def cholesky_solve(ATA, ATy):
    if ATA.dtype == torch.half or ATy.dtype == torch.half:
        return ATy.to(torch.float).cholesky_solve(torch.cholesky(ATA.to(torch.float))).to(ATy.dtype)
    return ATy.cholesky_solve(torch.cholesky(ATA)).to(ATy.dtype)
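
# cholesky_solve solves the batched normal equations ATA @ x = ATy by factorizing
# ATA with torch.cholesky; half-precision inputs are upcast to float32 first,
# since the dense Cholesky kernels do not run in fp16. Illustrative least-squares
# use (not executed here):
#   A = torch.randn(4, 10, 3); b = torch.randn(4, 10, 1)
#   x = cholesky_solve(A.transpose(1, 2) @ A, A.transpose(1, 2) @ b)
#   # x approximately minimizes ||A x - b|| for each problem in the batch.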


def omp_v0(X, y, XTX, n_nonzero_coefs=None, tol=None, inverse_cholesky=True):
    B = y.shape[0]
    normr2 = innerp(y)  # Norm squared of residual.
    projections = (X.transpose(1, 0) @ y[:, :, None]).squeeze(-1)
    sets = y.new_zeros(n_nonzero_coefs, B, dtype=torch.int64)
    if inverse_cholesky:
        # Building the inverse Cholesky factor iteratively uses more memory, but it seems
        # to take less time than solving the full system at the end.
        # (Since F is triangular it could be __even faster__ to multiply, prob. not on GPU tho.)
        F = torch.eye(n_nonzero_coefs, dtype=y.dtype, device=y.device).repeat(B, 1, 1)
        a_F = y.new_zeros(n_nonzero_coefs, B, 1)
    D_mybest = y.new_empty(B, n_nonzero_coefs, XTX.shape[0])
    temp_F_k_k = y.new_ones((B, 1))
    if tol:
        result_lengths = sets.new_zeros(y.shape[0])
        result_solutions = y.new_zeros((y.shape[0], n_nonzero_coefs, 1))
        finished_problems = sets.new_zeros(y.shape[0], dtype=torch.bool)
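
    # Each iteration below does, for every problem in the batch, roughly:
    #   1. pick the atom with the largest absolute correlation to the current
    #      residual (argmax over `projections`),
    #   2. extend F, the inverse Cholesky factor of the selected atoms' Gram
    #      matrix, by one row (only when inverse_cholesky=True),
    #   3. update `projections` and `normr2` in place, so the residual itself
    #      is never formed explicitly.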
    for k in range(n_nonzero_coefs + bool(tol)):
        # STOPPING CRITERIA
        if tol:
            problems_done = normr2 <= tol
            if k == n_nonzero_coefs:
                problems_done[:] = True
            if problems_done.any():
                new_problems_done = problems_done & ~finished_problems
                finished_problems.logical_or_(problems_done)
                result_lengths[new_problems_done] = k
                if inverse_cholesky:
                    result_solutions[new_problems_done, :k] = F[new_problems_done, :k, :k].permute(0, 2, 1) @ a_F[:k, new_problems_done].permute(1, 0, 2)
                else:
                    assert False, "inverse_cholesky=False with tol != None is not handled yet"
                if problems_done.all():
                    return sets.t(), result_solutions, result_lengths
        sets[k] = projections.abs().argmax(1)
        # D_mybest[:, k, :] = XTX[gamma[k], :]  # Same line as below, but significantly slower. (prob. due to the intermediate array creation)
        torch.gather(XTX, 0, sets[k, :, None].expand(-1, XTX.shape[1]), out=D_mybest[:, k, :])
        if k:
            D_mybest_maxindices = D_mybest.permute(0, 2, 1)[torch.arange(D_mybest.shape[0], dtype=sets.dtype, device=sets.device), sets[k], :k]
            torch.rsqrt(abs(1 - innerp(D_mybest_maxindices)),
                        out=temp_F_k_k[:, 0])  # torch.exp(-1/2 * torch.log1p(-inp), temp_F_k_k[:, 0])
            D_mybest_maxindices *= -temp_F_k_k  # minimal operations, exploit linearity
            D_mybest[:, k, :] *= temp_F_k_k
            D_mybest[:, k, :, None].baddbmm_(D_mybest[:, :k, :].permute(0, 2, 1), D_mybest_maxindices[:, :, None])
        temp_a_F = temp_F_k_k * torch.gather(projections, 1, sets[k, :, None])
        normr2 -= (temp_a_F * temp_a_F).squeeze(-1)
        projections -= temp_a_F * D_mybest[:, k, :]
        if inverse_cholesky:
            a_F[k] = temp_a_F
            if k:  # Could maybe get a speedup from triangular mat mul kernel.
                torch.bmm(D_mybest_maxindices[:, None, :], F[:, :k, :], out=F[:, k, None, :])
            F[:, k, k] = temp_F_k_k[..., 0]
    else:  # FIXME: else branch will not execute if n_nonzero_coefs=0, so solutions is undefined.
        # Normal exit, used when tolerance=None.
        if inverse_cholesky:
            solutions = F.permute(0, 2, 1) @ a_F.squeeze(-1).transpose(1, 0)[:, :, None]
        else:
            # Solving the problem in the end without using inverse Cholesky.
            AT = X.T[sets.T]
            solutions = cholesky_solve(AT @ AT.permute(0, 2, 1), AT @ y.T[:, :, None])
    return sets.t(), solutions, None
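
# Shape conventions used above and in run_omp below: X is the dictionary of
# shape (n_samples, n_features), y is a batch of targets of shape
# (n_problems, n_samples), and run_omp returns dense coefficient estimates of
# shape (n_problems, n_features).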


def run_omp(X, y, n_nonzero_coefs, precompute=True, tol=0.0, normalize=False, fit_intercept=False, alg='naive'):
    if not isinstance(X, torch.Tensor):
        X = torch.as_tensor(X)
        y = torch.as_tensor(y)
    # We can either return sets, (sets, solutions), or xests.
    # These are all equivalent, but are simply more and more dense representations.
    # Given sets, X and y one can (re-)construct xests. The second is just a sparse vector repr.
    # https://github.com/scikit-learn/scikit-learn/blob/15a949460dbf19e5e196b8ef48f9712b72a3b3c3/sklearn/linear_model/_omp.py#L690
    if fit_intercept or normalize:
        X = X.clone()
        assert not isinstance(precompute, torch.Tensor), "If the user pre-computes XTX they can also pre-normalize X," \
                                                         " so normalize and fit_intercept must be set false."
    if fit_intercept:
        X = X - X.mean(0)
        y = y - y.mean(1)[:, None]
    # To keep a good condition number on X, especially with Cholesky compared to LU factorization,
    # we should probably always normalize it (OMP is invariant anyways).
    if normalize is True:  # User can also just optionally supply pre-computed norms.
        normalize = (X * X).sum(0).sqrt()
        X /= normalize[None, :]
    if precompute is True or alg == 'v0':
        precompute = X.T @ X
    # If n_nonzero_coefs is equal to M, one should just return lstsq.
    if alg == 'v0':
        sets, solutions, lengths = omp_v0(X, y, n_nonzero_coefs=n_nonzero_coefs, XTX=precompute, tol=tol)
    elif alg == 'naive':
        # omp_naive is not defined in this file; it is expected to be provided elsewhere.
        sets, solutions, lengths = omp_naive(X, y, n_nonzero_coefs=n_nonzero_coefs, XTX=precompute, tol=tol)
    solutions = solutions.squeeze(-1)
    if normalize is not False:
        solutions /= normalize[sets]
    xests = y.new_zeros(y.shape[0], X.shape[1])
    if lengths is None:
        xests[torch.arange(y.shape[0], dtype=sets.dtype, device=sets.device)[:, None], sets] = solutions
    else:
        for i in range(y.shape[0]):
            # xests[i].scatter_(-1, sets[i, :lengths[i]], solutions[i, :lengths[i]])
            xests[i, sets[i, :lengths[i]]] = solutions[i, :lengths[i]]
    return xests
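

# Illustrative usage sketch: recover a batch of synthetic k-sparse signals.
# alg='v0' is used because omp_naive is not defined in this file; the shapes,
# seed, and variable names below are arbitrary choices for the example.
if __name__ == '__main__':
    torch.manual_seed(0)
    n_samples, n_features, n_problems, k = 64, 128, 8, 5
    # Dictionary with unit-norm columns and a batch of k-sparse coefficient vectors.
    X_dict = torch.randn(n_samples, n_features)
    X_dict /= X_dict.norm(dim=0, keepdim=True)
    coefs = torch.zeros(n_problems, n_features)
    for b in range(n_problems):
        support = torch.randperm(n_features)[:k]
        coefs[b, support] = torch.randn(k)
    Y = coefs @ X_dict.T  # Targets, shape (n_problems, n_samples).
    estimates = run_omp(X_dict, Y, n_nonzero_coefs=k, alg='v0')
    print('max abs coefficient error:', (estimates - coefs).abs().max().item())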