Skip to content

Commit

Permalink
BUG: fix cholesky upper decomp for complex dtypes (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley authored Jan 3, 2024
1 parent 874c2ff commit 88814e5
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
else:
from numpy.core.numeric import normalize_axis_tuple

from ._aliases import matmul, matrix_transpose, tensordot, vecdot
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from .._internal import get_xp

# These are in the main NumPy namespace but not in numpy.linalg
Expand Down Expand Up @@ -59,7 +59,10 @@ def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
L = xp.linalg.cholesky(x, **kwargs)
if upper:
return get_xp(xp)(matrix_transpose)(L)
U = get_xp(xp)(matrix_transpose)(L)
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
U = xp.conj(U)
return U
return L

# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
Expand Down

0 comments on commit 88814e5

Please sign in to comment.