Skip to content

Commit

Permalink
Force MKL TruncatedSVD to float
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Apr 30, 2024
1 parent 1bcf4dc commit 41e46fa
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
15 changes: 13 additions & 2 deletions inferelator_velocity/utils/_truncated_mkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,13 @@ def randomized_range_finder(
# Perform power iterations with Q to further 'imprint' the top
# singular vectors of A in Q
for _ in range(n_iter):
Q, _ = normalizer(safe_sparse_dot(A, Q, dense_output=True))
Q, _ = normalizer(
safe_sparse_dot(
A,
Q,
dense_output=True
)
)
Q, _ = normalizer(
safe_sparse_dot(
np.ascontiguousarray(Q.T),
Expand Down Expand Up @@ -219,7 +225,12 @@ def randomized_svd(
class TruncatedSVDMKL(TruncatedSVD):

def fit_transform(self, X, y=None):
X = self._validate_data(X, accept_sparse=["csr", "csc"], ensure_min_features=2)
X = self._validate_data(
X,
accept_sparse=["csr", "csc"],
ensure_min_features=2,
dtype=[np.float64, np.float32, float]
)
random_state = check_random_state(self.random_state)

if self.algorithm == "arpack":
Expand Down
3 changes: 1 addition & 2 deletions inferelator_velocity/utils/mcv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from inferelator_velocity.utils.math import (
pairwise_metric,
mcv_mse,
array_sum,
sparse_dot_patch
array_sum
)

try:
Expand Down

0 comments on commit 41e46fa

Please sign in to comment.