diff --git a/docs/changelog.md b/docs/changelog.md index 927d7119..2cf48671 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -10,6 +10,11 @@ Release notes for `quimb`. - [qu.randn](quimb.randn): support `dist="rademacher"`. - support `dist` and other `randn` options in various TN builders. +**Bug fixes:** + +- restore fallback (to `scipy.linalg.svd` with driver='gesvd') behavior for + truncated SVD with numpy backend. + (whats-new-1-7-2)= ## v1.7.2 (2024-01-30) diff --git a/quimb/tensor/decomp.py b/quimb/tensor/decomp.py index 3eb1f6ae..b5fcb574 100644 --- a/quimb/tensor/decomp.py +++ b/quimb/tensor/decomp.py @@ -3,10 +3,12 @@ import functools import operator +import warnings import numpy as np -import scipy.sparse.linalg as spla +import scipy.linalg as scla import scipy.linalg.interpolative as sli +import scipy.sparse.linalg as spla from autoray import ( astype, backend_like, @@ -313,7 +315,6 @@ def _trim_and_renorm_svd_result_numba( return U, None, VH -@svd_truncated.register("numpy") @njit # pragma: no cover def svd_truncated_numba( x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0 @@ -325,6 +326,25 @@ def svd_truncated_numba( ) +@svd_truncated.register("numpy") +def svd_truncated_numpy( + x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0 +): + """Numpy version of ``svd_truncated``, trying the accelerated version + first, then falling back to the more stable scipy version. + """ + try: + return svd_truncated_numba( + x, cutoff, cutoff_mode, max_bond, absorb, renorm + ) + except np.linalg.LinAlgError as e: # pragma: no cover + warnings.warn(f"Got: {e}, falling back to scipy gesvd driver.") + U, s, VH = scla.svd(x, full_matrices=False, lapack_driver="gesvd") + return _trim_and_renorm_svd_result_numba( + U, s, VH, cutoff, cutoff_mode, max_bond, absorb, renorm + ) + + @svd_truncated.register("autoray.lazy") @lazy.core.lazy_cache("svd_truncated") def svd_truncated_lazy(