diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 97c584be..ce5b55d1 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -11,7 +11,7 @@ else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matmul, matrix_transpose, tensordot, vecdot +from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg @@ -59,7 +59,10 @@ def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: L = xp.linalg.cholesky(x, **kwargs) if upper: - return get_xp(xp)(matrix_transpose)(L) + U = get_xp(xp)(matrix_transpose)(L) + if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): + U = xp.conj(U) + return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.