diff --git a/numba_scipy/linalg/LAPACK.py b/numba_scipy/linalg/LAPACK.py new file mode 100644 index 0000000..f63abd6 --- /dev/null +++ b/numba_scipy/linalg/LAPACK.py @@ -0,0 +1,238 @@ +from numba.extending import get_cython_function_address +from numba.np.linalg import ensure_lapack, _blas_kinds +import ctypes + +_PTR = ctypes.POINTER + +_dbl = ctypes.c_double +_float = ctypes.c_float +_char = ctypes.c_char +_int = ctypes.c_int + +_ptr_float = _PTR(_float) +_ptr_dbl = _PTR(_dbl) +_ptr_char = _PTR(_char) +_ptr_int = _PTR(_int) + + +def _get_float_pointer_for_dtype(blas_dtype): + if blas_dtype in ['s', 'c']: + return _ptr_float + elif blas_dtype in ['d', 'z']: + return _ptr_dbl + + +class _LAPACK: + """ + Functions to return type signatures for wrapped + LAPACK functions. + """ + + def __init__(self): + ensure_lapack() + + @classmethod + def test_blas_kinds(cls, dtype): + return _blas_kinds[dtype] + + @classmethod + def numba_rgees(cls, dtype): + d = _blas_kinds[dtype] + func_name = f'{d}gees' + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name) + functype = ctypes.CFUNCTYPE(None, + _ptr_int, # JOBVS + _ptr_int, # SORT + _ptr_int, # SELECT + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # SDIM + float_pointer, # WR + float_pointer, # WI + float_pointer, # VS + _ptr_int, # LDVS + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # BWORK + _ptr_int) # INFO + return functype(addr) + + @classmethod + def numba_cgees(cls, dtype): + d = _blas_kinds[dtype] + func_name = f'{d}gees' + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name) + functype = ctypes.CFUNCTYPE(None, + _ptr_int, # JOBVS + _ptr_int, # SORT + _ptr_int, # SELECT + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # SDIM + float_pointer, # W + float_pointer, # VS + _ptr_int, # LDVS + float_pointer, # WORK + _ptr_int, # LWORK + float_pointer, # RWORK + _ptr_int, # BWORK + _ptr_int) # INFO + return functype(addr) + + @classmethod + def numba_rgges(cls, dtype): + d = _blas_kinds[dtype] + func_name = f'{d}gges' + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name) + + functype = ctypes.CFUNCTYPE(None, + _ptr_int, # JOBVSL + _ptr_int, # JOBVSR + _ptr_int, # SORT + _ptr_int, # SELCTG + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # SDIM + float_pointer, # ALPHAR + float_pointer, # ALPHAI + float_pointer, # BETA + float_pointer, # VSL + _ptr_int, # LDVSL + float_pointer, # VSR + _ptr_int, # LDVSR + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # BWORK + _ptr_int) # INFO + return functype(addr) + + @classmethod + def numba_cgges(cls, dtype): + d = _blas_kinds[dtype] + func_name = f'{d}gges' + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name) + + functype = ctypes.CFUNCTYPE(None, + _ptr_int, # JOBVSL + _ptr_int, # JOBVSR + _ptr_int, # SORT + _ptr_int, # SELCTG + _ptr_int, # N + float_pointer, # A, complex + _ptr_int, # LDA + float_pointer, # B, complex + _ptr_int, # LDB + _ptr_int, # SDIM + float_pointer, # ALPHA, complex + float_pointer, # BETA, complex + float_pointer, # VSL, complex + _ptr_int, # LDVSL + float_pointer, # VSR, complex + _ptr_int, # LDVSR + float_pointer, # WORK, complex + _ptr_int, # LWORK + float_pointer, # RWORK + _ptr_int, # 
BWORK + _ptr_int) # INFO + return functype(addr) + + @classmethod + def numba_rtgsen(cls, dtype): + d = _blas_kinds[dtype] + func_name = f'{d}tgsen' + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name) + + functype = ctypes.CFUNCTYPE(None, + _ptr_int, # IJOB + _ptr_int, # WANTQ + _ptr_int, # WANTZ + _ptr_int, # SELECT + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + float_pointer, # ALPHAR + float_pointer, # ALPHAI + float_pointer, # BETA + float_pointer, # Q + _ptr_int, # LDQ + float_pointer, # Z + _ptr_int, # LDZ + _ptr_int, # M + float_pointer, # PL + float_pointer, # PR + float_pointer, # DIF + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # IWORK + _ptr_int, # LIWORK + _ptr_int) # INFO + return functype(addr) + + @classmethod + def numba_ctgsen(cls, dtype): + d = _blas_kinds[dtype] + func_name = f'{d}tgsen' + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name) + + functype = ctypes.CFUNCTYPE(None, + _ptr_int, # IJOB + _ptr_int, # WANTQ + _ptr_int, # WANTZ + _ptr_int, # SELECT + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + float_pointer, # ALPHA + float_pointer, # BETA + float_pointer, # Q + _ptr_int, # LDQ + float_pointer, # Z + _ptr_int, # LDZ + _ptr_int, # M + float_pointer, # PL + float_pointer, # PR + float_pointer, # DIF + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # IWORK + _ptr_int, # LIWORK + _ptr_int) # INFO + return functype(addr) + + @classmethod + def numba_xtrsyl(cls, dtype): + d = _blas_kinds[dtype] + func_name = f'{d}trsyl' + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name) + + functype = ctypes.CFUNCTYPE(None, + _ptr_int, # TRANA + _ptr_int, # TRANB + _ptr_int, # ISGN + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + float_pointer, # C + _ptr_int, # LDC + float_pointer, # SCALE + _ptr_int) # INFO + return functype(addr) diff --git a/numba_scipy/linalg/__init__.py b/numba_scipy/linalg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/numba_scipy/linalg/intrinsics.py b/numba_scipy/linalg/intrinsics.py new file mode 100644 index 0000000..b5f4e1e --- /dev/null +++ b/numba_scipy/linalg/intrinsics.py @@ -0,0 +1,72 @@ +from numba.core import types, cgutils +from numba.extending import intrinsic + + +@intrinsic +def val_to_dptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float64)(types.float64) + return sig, impl + + +@intrinsic +def val_to_zptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.complex128)(types.complex128) + return sig, impl + + +@intrinsic +def val_to_sptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float32)(types.float32) + return sig, impl + + +@intrinsic +def val_to_int_ptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.int32)(types.int32) + return sig, impl + + +@intrinsic +def 
int_ptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.int32(types.CPointer(types.int32)) + return sig, impl + + +@intrinsic +def dptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float64(types.CPointer(types.float64)) + return sig, impl + + +@intrinsic +def sptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float32(types.CPointer(types.float32)) + return sig, impl diff --git a/numba_scipy/linalg/overloads.py b/numba_scipy/linalg/overloads.py new file mode 100644 index 0000000..e453424 --- /dev/null +++ b/numba_scipy/linalg/overloads.py @@ -0,0 +1,557 @@ +from numba.core import types, cgutils +from numba.extending import overload +from numba.np.linalg import ensure_lapack, _check_finite_matrix, _copy_to_fortran_order, \ + _handle_err_maybe_convergence_problem + +import scipy +import numpy as np +from scipy import linalg + +from numba_scipy.linalg.utilities import _check_scipy_linalg_matrix, _get_underlying_float, _ouc, _iuc, _lhp, _rhp, \ + direct_lyapunov_solution +from numba_scipy.linalg.intrinsics import val_to_int_ptr, int_ptr_to_val +from numba_scipy.linalg.LAPACK import _LAPACK + + +@overload(scipy.linalg.schur) +def schur_impl(A, output): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "schur") + + dtype = A.dtype + w_type = _get_underlying_float(dtype) + + numba_rgees = _LAPACK().numba_rgees(dtype) + numba_cgees = _LAPACK().numba_cgees(dtype) + + def real_schur_impl(A, output): + """ + schur() implementation for real arrays + """ + _N = np.int32(A.shape[-1]) + if A.shape[-2] != _N: + msg = "Last 2 dimensions of the array must be square" + raise linalg.LinAlgError(msg) + + _check_finite_matrix(A) + A_copy = _copy_to_fortran_order(A) + + JOBVS = val_to_int_ptr(ord('V')) + SORT = val_to_int_ptr(ord('N')) + SELECT = val_to_int_ptr(0.0) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + SDIM = val_to_int_ptr(_N) + WR = np.empty(_N, dtype=dtype) + WI = np.empty(_N, dtype=dtype) + _LDVS = _N + LDVS = val_to_int_ptr(_N) + VS = np.empty((_LDVS, _N), dtype=dtype) + LWORK = val_to_int_ptr(-1) + WORK = np.empty(1, dtype=dtype) + BWORK = val_to_int_ptr(1) + INFO = val_to_int_ptr(1) + + # workspace query + numba_rgees(JOBVS, SORT, SELECT, N, A_copy.ctypes, LDA, SDIM, WR.ctypes, WI.ctypes, VS.ctypes, LDVS, + WORK.ctypes, LWORK, BWORK, INFO) + WS_SIZE = np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual work + numba_rgees(JOBVS, SORT, SELECT, N, A_copy.ctypes, LDA, SDIM, WR.ctypes, WI.ctypes, VS.ctypes, LDVS, + WORK.ctypes, LWORK, BWORK, INFO) + + # if np.any(WI) and output == 'complex': + # raise ValueError("schur() argument must not cause a domain change.") + _handle_err_maybe_convergence_problem(int_ptr_to_val(INFO)) + + return A_copy, VS.T + + def complex_schur_impl(A, output): + """ + schur() implementation for complex arrays + """ + + _N = np.int32(A.shape[-1]) + if A.shape[-2] != _N: + msg = "Last 2 dimensions of the array must be square" + raise linalg.LinAlgError(msg) + + _check_finite_matrix(A) + A_copy = _copy_to_fortran_order(A) + + JOBVS = val_to_int_ptr(ord('V')) + SORT = val_to_int_ptr(ord('N')) + SELECT = val_to_int_ptr(0.0) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + SDIM = val_to_int_ptr(_N) + W = np.empty(_N, dtype=dtype) + _LDVS = _N + LDVS = 
val_to_int_ptr(_N) + VS = np.empty((_LDVS, _N), dtype=dtype) + LWORK = val_to_int_ptr(-1) + WORK = np.empty(1, dtype=dtype) + RWORK = np.empty(_N, dtype=w_type) + BWORK = val_to_int_ptr(1) + INFO = val_to_int_ptr(1) + + # workspace query + numba_cgees(JOBVS, SORT, SELECT, N, A_copy.view(w_type).ctypes, LDA, SDIM, W.view(w_type).ctypes, + VS.view(w_type).ctypes, LDVS, WORK.view(w_type).ctypes, LWORK, RWORK.ctypes, BWORK, INFO) + + WS_SIZE = np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual work + numba_cgees(JOBVS, SORT, SELECT, N, A_copy.view(w_type).ctypes, LDA, SDIM, W.view(w_type).ctypes, + VS.view(w_type).ctypes, LDVS, WORK.view(w_type).ctypes, LWORK, RWORK.ctypes, BWORK, INFO) + + _handle_err_maybe_convergence_problem(int_ptr_to_val(INFO)) + + return A_copy, VS.T + + if isinstance(A.dtype, types.scalars.Complex): + return complex_schur_impl + else: + return real_schur_impl + + +def full_return_qz(A, B, output): + pass + + +@overload(full_return_qz) +def full_return_qz_impl(A, B, output): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "qz") + _check_scipy_linalg_matrix(B, "qz") + + dtype = A.dtype + w_type = _get_underlying_float(dtype) + + numba_rgges = _LAPACK().numba_rgges(dtype) + numba_cgges = _LAPACK().numba_cgges(dtype) + + def real_full_return_qz_impl(A, B, output): + """ + schur() implementation for real arrays. Unlike the Scipy function, this has 5 returns, including the + generalized eigenvalues (alpha, beta), because these are required by ordqz. + """ + _M, _N = np.int32(A.shape[-2:]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + if B.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of B must be square") + + _check_finite_matrix(A) + _check_finite_matrix(B) + + A_copy = _copy_to_fortran_order(A) + B_copy = _copy_to_fortran_order(B) + + JOBVSL = val_to_int_ptr(ord('V')) + JOBVSR = val_to_int_ptr(ord('V')) + SORT = val_to_int_ptr(ord('N')) + SELCTG = val_to_int_ptr(1) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + SDIM = val_to_int_ptr(0) + + ALPHAR = np.empty(_N, dtype=dtype) # out + ALPHAI = np.empty(_N, dtype=dtype) # out + BETA = np.empty(_N, dtype=dtype) # out + + _LDVSL = _N + _LDVSR = _N + LDVSL = val_to_int_ptr(_LDVSL) + VSL = np.empty((_LDVSL, _N), dtype=dtype) # out + LDVSR = val_to_int_ptr(_LDVSR) + VSR = np.empty((_LDVSR, _N), dtype=dtype) # out + + WORK = np.empty((1,), dtype=dtype) # out + LWORK = val_to_int_ptr(-1) + BWORK = val_to_int_ptr(1) + INFO = val_to_int_ptr(1) + + # workspace query + numba_rgges(JOBVSL, JOBVSR, SORT, SELCTG, N, A_copy.ctypes, LDA, B_copy.ctypes, LDB, + SDIM, ALPHAR.ctypes, ALPHAI.ctypes, BETA.ctypes, VSL.ctypes, LDVSL, + VSR.ctypes, LDVSR, WORK.ctypes, LWORK, BWORK, INFO) + + WS_SIZE = np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual work + numba_rgges(JOBVSL, JOBVSR, SORT, SELCTG, N, A_copy.ctypes, LDA, B_copy.ctypes, LDB, + SDIM, ALPHAR.ctypes, ALPHAI.ctypes, BETA.ctypes, VSL.ctypes, LDVSL, + VSR.ctypes, LDVSR, WORK.ctypes, LWORK, BWORK, INFO) + + _handle_err_maybe_convergence_problem(int_ptr_to_val(INFO)) + ALPHA = ALPHAR + ALPHAI * 1j + + return A_copy, B_copy, ALPHA, BETA, VSL.T, VSR.T + + def complex_full_return_qz_impl(A, B, output): + """ + qz decomposition for complex arrays. 
Unlike the Scipy function, this has 5 returns, including the + generalized eigenvalues (alpha, beta), because these are required by ordqz. + """ + + _M, _N = np.int32(A.shape[-2:]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + if B.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of B must be square") + + _check_finite_matrix(A) + _check_finite_matrix(B) + + A_copy = _copy_to_fortran_order(A) + B_copy = _copy_to_fortran_order(B) + + JOBVSL = val_to_int_ptr(ord('V')) + JOBVSR = val_to_int_ptr(ord('V')) + SORT = val_to_int_ptr(ord('N')) + SELCTG = val_to_int_ptr(1) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + SDIM = val_to_int_ptr(0) + + ALPHA = np.empty(_N, dtype=dtype) # out + BETA = np.empty(_N, dtype=dtype) # out + LDVSL = val_to_int_ptr(_N) + VSL = np.empty((_N, _N), dtype=dtype) # out + LDVSR = val_to_int_ptr(_N) + VSR = np.empty((_N, _N), dtype=dtype) # out + + WORK = np.empty((1,), dtype=dtype) # out + LWORK = val_to_int_ptr(-1) + RWORK = np.empty(8 * _N, dtype=w_type) + BWORK = val_to_int_ptr(1) + INFO = val_to_int_ptr(1) + + # workspace query + numba_cgges(JOBVSL, JOBVSR, SORT, SELCTG, N, A_copy.view(w_type).ctypes, LDA, B_copy.view(w_type).ctypes, LDB, + SDIM, ALPHA.view(w_type).ctypes, BETA.view(w_type).ctypes, VSL.view(w_type).ctypes, + LDVSL, VSR.view(w_type).ctypes, LDVSR, WORK.view(w_type).ctypes, LWORK, RWORK.ctypes, BWORK, INFO) + + WS_SIZE = np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual work + numba_cgges(JOBVSL, JOBVSR, SORT, SELCTG, N, A_copy.view(w_type).ctypes, LDA, B_copy.view(w_type).ctypes, LDB, + SDIM, ALPHA.view(w_type).ctypes, BETA.view(w_type).ctypes, VSL.view(w_type).ctypes, + LDVSL, VSR.view(w_type).ctypes, LDVSR, WORK.view(w_type).ctypes, LWORK, RWORK.ctypes, BWORK, INFO) + + _handle_err_maybe_convergence_problem(int_ptr_to_val(INFO)) + + return A_copy, B_copy, ALPHA, BETA, VSL.T, VSR.T + + if isinstance(A.dtype, types.scalars.Complex): + return complex_full_return_qz_impl + else: + return real_full_return_qz_impl + + +@overload(scipy.linalg.qz) +def qz_impl(A, B, output): + """ + scipy.linalg.qz overload. Wraps full_return_qz and returns only A, B, Q ,Z to match the scipy signature. 
+ """ + ensure_lapack() + + _check_scipy_linalg_matrix(A, "qz") + _check_scipy_linalg_matrix(B, "qz") + + def real_qz_impl(A, B, output): + A, B, ALPHA, BETA, VSL, VSR = full_return_qz(A, B, output) + + return A, B, VSL, VSR + + def complex_qz_impl(A, B, output): + A, B, ALPHA, BETA, VSL, VSR = full_return_qz(A, B, output) + return A, B, VSL, VSR + + if isinstance(A.dtype, types.scalars.Complex): + return complex_qz_impl + else: + return real_qz_impl + + +@overload(scipy.linalg.ordqz) +def ordqz_impl(A, B, sort, output): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "ordqz") + _check_scipy_linalg_matrix(B, "ordqz") + + dtype = A.dtype + w_type = _get_underlying_float(dtype) + + numba_rtgsen = _LAPACK().numba_rtgsen(dtype) + numba_ctgsen = _LAPACK().numba_ctgsen(dtype) + + def real_ordqz_impl(A, B, sort, output): + _M, _N = np.int32(A.shape[-2:]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + if B.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of B must be square") + + _check_finite_matrix(A) + _check_finite_matrix(B) + + if sort not in ['lhp', 'rhp', 'iuc', 'ouc']: + raise ValueError('Argument "sort" should be one of: "lhp", "rhp", "iuc", "ouc"') + + A_copy = _copy_to_fortran_order(A) + B_copy = _copy_to_fortran_order(B) + + AA, BB, ALPHA, BETA, Q, Z = full_return_qz(A_copy, B_copy, output) + + if sort == 'lhp': + SELECT = _lhp(ALPHA, BETA) + elif sort == 'rhp': + SELECT = _rhp(ALPHA, BETA) + elif sort == 'iuc': + SELECT = _iuc(ALPHA, BETA) + elif sort == 'ouc': + SELECT = _ouc(ALPHA, BETA) + + IJOB = val_to_int_ptr(0) + WANTQ = val_to_int_ptr(1) + WANTZ = val_to_int_ptr(1) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_M) + LDB = val_to_int_ptr(_M) + + ALPHAR = np.empty(_N, dtype=dtype) + ALPHAI = np.empty(_N, dtype=dtype) + + LDQ = val_to_int_ptr(Q.shape[0]) + LDZ = val_to_int_ptr(Z.shape[0]) + M = val_to_int_ptr(_M) + PL = np.empty(1, dtype=dtype) + PR = np.empty(1, dtype=dtype) + DIF = np.empty(2, dtype=dtype) + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + IWORK = np.empty(1, dtype=np.int32) + LIWORK = val_to_int_ptr(-1) + INFO = val_to_int_ptr(1) + + # workspace query + numba_rtgsen(IJOB, WANTQ, WANTZ, SELECT.ctypes, N, AA.ctypes, LDA, BB.ctypes, LDB, ALPHAR.ctypes, + ALPHAI.ctypes, BETA.ctypes, Q.ctypes, LDQ, Z.ctypes, LDZ, M, PL.ctypes, + PR.ctypes, DIF.ctypes, WORK.ctypes, LWORK, IWORK.ctypes, LIWORK, INFO) + + WS_SIZE = np.int32(WORK[0].real) + IW_SIZE = np.int32(IWORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + LIWORK = val_to_int_ptr(IW_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + IWORK = np.empty(IW_SIZE, dtype=np.int32) + + numba_rtgsen(IJOB, WANTQ, WANTZ, SELECT.ctypes, N, AA.ctypes, LDA, BB.ctypes, LDB, ALPHAR.ctypes, + ALPHAI.ctypes, BETA.ctypes, Q.ctypes, LDQ, Z.ctypes, LDZ, M, PL.ctypes, + PR.ctypes, DIF.ctypes, WORK.ctypes, LWORK, IWORK.ctypes, LIWORK, INFO) + + # if np.any(ALPHAI) and output == 'complex': + # raise ValueError("ordqz() argument must not cause a domain change.") + _handle_err_maybe_convergence_problem(int_ptr_to_val(INFO)) + ALPHA = ALPHAR + 1j * ALPHAI + return AA, BB, ALPHA, BETA, Q, Z + + def complex_ordqz_impl(A, B, sort, output): + _M, _N = np.int32(A.shape[-2:]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + if B.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of B must be square") + + _check_finite_matrix(A) + _check_finite_matrix(B) + + if sort not in ['lhp', 'rhp', 'iuc', 'ouc']: + raise 
ValueError('Argument "sort" should be one of: "lhp", "rhp", "iuc", "ouc"') + + A_copy = _copy_to_fortran_order(A) + B_copy = _copy_to_fortran_order(B) + + AA, BB, ALPHA, BETA, Q, Z = full_return_qz(A_copy, B_copy, output) + + if sort == 'lhp': + SELECT = _lhp(ALPHA, BETA) + elif sort == 'rhp': + SELECT = _rhp(ALPHA, BETA) + elif sort == 'iuc': + SELECT = _iuc(ALPHA, BETA) + elif sort == 'ouc': + SELECT = _ouc(ALPHA, BETA) + + IJOB = val_to_int_ptr(0) + WANTQ = val_to_int_ptr(1) + WANTZ = val_to_int_ptr(1) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_M) + LDB = val_to_int_ptr(_M) + + LDQ = val_to_int_ptr(Q.shape[0]) + LDZ = val_to_int_ptr(Z.shape[0]) + M = val_to_int_ptr(_M) + PL = np.empty(1, dtype=w_type) + PR = np.empty(1, dtype=w_type) + DIF = np.empty(2, dtype=w_type) + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + IWORK = np.empty(1, dtype=np.int32) + LIWORK = val_to_int_ptr(-1) + INFO = val_to_int_ptr(1) + + # workspace query + numba_ctgsen(IJOB, WANTQ, WANTZ, SELECT.ctypes, N, AA.view(w_type).ctypes, LDA, BB.view(w_type).ctypes, LDB, + ALPHA.view(w_type).ctypes, BETA.view(w_type).ctypes, Q.view(w_type).ctypes, LDQ, + Z.view(w_type).ctypes, LDZ, M, PL.ctypes, PR.ctypes, DIF.ctypes, + WORK.view(w_type).ctypes, LWORK, IWORK.ctypes, LIWORK, INFO) + + WS_SIZE = np.int32(WORK[0].real) + IW_SIZE = np.int32(IWORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + LIWORK = val_to_int_ptr(IW_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + IWORK = np.empty(IW_SIZE, dtype=np.int32) + + numba_ctgsen(IJOB, WANTQ, WANTZ, SELECT.ctypes, N, AA.view(w_type).ctypes, LDA, BB.view(w_type).ctypes, + LDB, ALPHA.view(w_type).ctypes, BETA.view(w_type).ctypes, Q.view(w_type).ctypes, LDQ, + Z.view(w_type).ctypes, LDZ, M, PL.ctypes, PR.ctypes, DIF.ctypes, WORK.view(w_type).ctypes, + LWORK, IWORK.ctypes, LIWORK, INFO) + + _handle_err_maybe_convergence_problem(int_ptr_to_val(INFO)) + + return AA, BB, ALPHA, BETA, Q, Z + + if isinstance(A.dtype, types.scalars.Complex): + return complex_ordqz_impl + else: + return real_ordqz_impl + + +@overload(scipy.linalg.solve_continuous_lyapunov) +def solve_continuous_lyapunov_impl(A, Q): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "solve_continuous_lyapunov") + _check_scipy_linalg_matrix(Q, "solve_continuous_lyapunov") + + dtype = A.dtype + w_type = _get_underlying_float(dtype) + + numba_xtrsyl = _LAPACK().numba_xtrsyl(dtype) + + def _solve_cont_lyapunov_impl(A, Q): + _M, _N = np.int32(A.shape) + _NQ = np.int32(Q.shape[-1]) + + if _N != _NQ: + raise linalg.LinAlgError('Matrices A and Q must have the same shape') + + if _M != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + if Q.shape[-2] != _NQ: + raise linalg.LinAlgError("Last 2 dimensions of Q must be square") + + _check_finite_matrix(A) + _check_finite_matrix(Q) + + is_complex = (np.iscomplexobj(A) | np.iscomplexobj(Q)) + dtype_letter = 'C' if is_complex else 'T' + output = 'complex' if is_complex else 'real' + + A_copy = _copy_to_fortran_order(A) + Q_copy = _copy_to_fortran_order(Q) + + R, U = linalg.schur(A_copy, output=output) + + # Construct f = u'*q*u + F = U.conj().T.dot(Q_copy.dot(U)) + + TRANA = val_to_int_ptr(ord('N')) + TRANB = val_to_int_ptr(ord(dtype_letter)) + ISGN = val_to_int_ptr(1) + + M = val_to_int_ptr(_N) + N = val_to_int_ptr(_N) + AA = _copy_to_fortran_order(R) + LDA = val_to_int_ptr(_N) + B = _copy_to_fortran_order(R) + LDB = val_to_int_ptr(_N) + C = _copy_to_fortran_order(F) + LDC = val_to_int_ptr(_N) + + # TODO: There is a little bit of overhead here, 
can I figure out how to assign a + # float or double pointer, depending on the case? + SCALE = np.array(1.0, dtype=w_type) + INFO = val_to_int_ptr(1) + + numba_xtrsyl(TRANA, TRANB, ISGN, M, N, + AA.view(w_type).ctypes, LDA, + B.view(w_type).ctypes, LDB, + C.view(w_type).ctypes, LDC, + SCALE.ctypes, INFO) + + C *= SCALE + _handle_err_maybe_convergence_problem(int_ptr_to_val(INFO)) + X = U.dot(C).dot(U.conj().T) + + return X + + return _solve_cont_lyapunov_impl + + +@overload(scipy.linalg.solve_discrete_lyapunov) +def solve_discrete_lyapunov_impl(A, Q, method='auto'): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "solve_continuous_lyapunov") + _check_scipy_linalg_matrix(Q, "solve_continuous_lyapunov") + + dtype = A.dtype + w_type = _get_underlying_float(dtype) + + def impl(A, Q, method='auto'): + _M, _N = np.int32(A.shape) + + if method == 'auto': + if _M < 10: + method = 'direct' + else: + method = 'bilinear' + + if method == 'direct': + X = direct_lyapunov_solution(A, Q) + + if method == 'bilinear': + eye = np.eye(_M) + AH = A.conj().transpose() + AHI_inv = np.linalg.inv(AH + eye) + B = np.dot(AH - eye, AHI_inv) + C = 2 * np.dot(np.dot(np.linalg.inv(A + eye), Q), AHI_inv) + X = linalg.solve_continuous_lyapunov(B.conj().transpose(), -C) + + return X + + return impl diff --git a/numba_scipy/linalg/utilities.py b/numba_scipy/linalg/utilities.py new file mode 100644 index 0000000..6903c44 --- /dev/null +++ b/numba_scipy/linalg/utilities.py @@ -0,0 +1,99 @@ +from numba import jit, njit, types +from numba.core import types, cgutils +from numba.core.errors import TypingError +import numpy as np + + +def _get_underlying_float(dtype): + s_dtype = str(dtype) + out_type = s_dtype + if s_dtype == 'complex64': + out_type = 'float32' + elif s_dtype == 'complex128': + out_type ='float64' + + return np.dtype(out_type) + + +def _check_scipy_linalg_matrix(a, func_name): + prefix = "scipy.linalg" + interp = (prefix, func_name) + # Unpack optional type + if isinstance(a, types.Optional): + a = a.type + if not isinstance(a, types.Array): + msg = "%s.%s() only supported for array types" % interp + raise TypingError(msg, highlighting=False) + if not a.ndim == 2: + msg = "%s.%s() only supported on 2-D arrays." % interp + raise TypingError(msg, highlighting=False) + if not isinstance(a.dtype, (types.Float, types.Complex)): + msg = "%s.%s() only supported on " \ + "float and complex arrays." % interp + raise TypingError(msg, highlighting=False) + + +@njit +def direct_lyapunov_solution(A, B): + lhs = np.kron(A, A.conj()) + lhs = np.eye(lhs.shape[0]) - lhs + x = np.linalg.solve(lhs, B.flatten()) + + return np.reshape(x, B.shape) + + +@njit +def _lhp(alpha, beta): + out = np.empty(alpha.shape, dtype=np.int32) + nonzero = (beta != 0) + # handles (x, y) = (0, 0) too + out[~nonzero] = False + out[nonzero] = (np.real(alpha[nonzero]/beta[nonzero]) < 0.0) + return out + +@njit +def _rhp(alpha, beta): + out = np.empty(alpha.shape, dtype=np.int32) + nonzero = (beta != 0) + # handles (x, y) = (0, 0) too + out[~nonzero] = False + out[nonzero] = (np.real(alpha[nonzero]/beta[nonzero]) > 0.0) + return out + +@njit +def _iuc(alpha, beta): + out = np.empty(alpha.shape, dtype=np.int32) + nonzero = (beta != 0) + # handles (x, y) = (0, 0) too + out[~nonzero] = False + out[nonzero] = (np.abs(alpha[nonzero]/beta[nonzero]) < 1.0) + + return out + +@njit +def _ouc(alpha, beta): + """ + Jit-aware version of the function scipy.linalg._decomp_qz._ouc, creates the mask needed for ztgsen to sort + eigenvalues from stable to unstable. 
+ + Parameters + ---------- + alpha: Array, complex + alpha vector, as returned by zgges + beta: Array, complex + beta vector, as return by zgges + Returns + ------- + out: Array, bool + Boolean mask indicating which eigenvalues are unstable + """ + + out = np.empty(alpha.shape, dtype=np.int32) + alpha_zero = (alpha == 0) + beta_zero = (beta == 0) + + out[alpha_zero & beta_zero] = False + out[~alpha_zero & beta_zero] = True + out[~beta_zero] = (np.abs(alpha[~beta_zero] / beta[~beta_zero]) > 1.0) + + return out \ No newline at end of file diff --git a/numba_scipy/tests/test_linalg.py b/numba_scipy/tests/test_linalg.py new file mode 100644 index 0000000..0044b95 --- /dev/null +++ b/numba_scipy/tests/test_linalg.py @@ -0,0 +1,476 @@ +import unittest +import numpy as np +from scipy import linalg +from numba import njit + +from numpy.testing import assert_allclose, assert_array_almost_equal, assert_almost_equal, assert_equal, \ + assert_array_equal + +from numba_scipy.linalg.overloads import qz_impl, ordqz_impl, schur_impl, solve_continuous_lyapunov_impl, \ + solve_discrete_lyapunov_impl +from numba_scipy.linalg.utilities import _iuc, _rhp, _ouc, _lhp + + +def make_data(n, dtype): + A = np.random.normal(size=(n, n)).astype(dtype) + B = np.random.normal(size=(n, n)).astype(dtype) + + if 'complex' in dtype: + A += (1j * np.random.normal(size=(n, n))).astype(dtype) + B += (1j * np.random.normal(size=(n, n))).astype(dtype) + + return A.astype(dtype), B.astype(dtype) + + +class numba_schur_test(unittest.TestCase): + def setUp(self) -> None: + @njit + def numba_schur_test(A, output='real'): + return linalg.schur(A, output) + + self.schur = numba_schur_test + + def test_numba_schur_float32_small(self): + n = 5 + A, _ = make_data(n, 'float32') + + t, z = linalg.schur(A) + T, Z = self.schur(A) + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A, decimal=3) + + def test_numba_schur_float32_large(self): + n = 100 + A, _ = make_data(n, 'float32') + + t, z = linalg.schur(A) + T, Z = self.schur(A) + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A, decimal=3) + + def test_numba_schur_float64_small(self): + n = 5 + A, _ = make_data(n, 'float64') + + t, z = linalg.schur(A) + T, Z = self.schur(A) + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A) + + def test_numba_schur_float64_large(self): + n = 100 + A, _ = make_data(n, 'float64') + + t, z = linalg.schur(A) + T, Z = self.schur(A) + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A) + + def test_numba_schur_complex64_small(self): + n = 5 + A, _ = make_data(n, 'complex64') + + t, z = linalg.schur(A) + T, Z = self.schur(A, output='complex') + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A, decimal=3) + + def test_numba_schur_complex64_large(self): + n = 100 + A, _ = make_data(n, 'complex64') + + t, z = linalg.schur(A) + T, Z = self.schur(A, output='complex') + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A, decimal=3) + + def test_numba_schur_complex128_small(self): + n = 5 + A, _ = make_data(n, 'complex128') + + t, z = linalg.schur(A) + T, Z = self.schur(A, output='complex') + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A) + + def test_numba_schur_complex128_large(self): + n = 100 + A, _ = make_data(n, 
'complex128') + + t, z = linalg.schur(A) + T, Z = self.schur(A, output='complex') + + assert_allclose(t, T) + assert_allclose(z, Z) + + assert_array_almost_equal(Z @ T @ Z.conj().T, A) + + +class numba_qz_test(unittest.TestCase): + def setUp(self) -> None: + @njit + def numba_qz_test(A, B, output='real'): + return linalg.qz(A, B, output) + + self.qz = numba_qz_test + + def test_numba_qz_float32_small(self): + n = 5 + A, B = make_data(n, 'float32') + aa, bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.T, A, decimal=5) + assert_array_almost_equal(Q @ BB @ Z.T, B, decimal=5) + assert_array_almost_equal(Q @ Q.T, np.eye(n), decimal=5) + assert_array_almost_equal(Z @ Z.T, np.eye(n), decimal=5) + assert (np.all(np.diag(BB) >= 0)) + + def test_numba_qz_float32_large(self): + n = 100 + A, B = make_data(n, 'float32') + aa, bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.T, A, decimal=3) + assert_array_almost_equal(Q @ BB @ Z.T, B, decimal=3) + assert_array_almost_equal(Q @ Q.T, np.eye(n), decimal=3) + assert_array_almost_equal(Z @ Z.T, np.eye(n), decimal=3) + assert (np.all(np.diag(BB) >= 0)) + + def test_numba_qz_float64_small(self): + n = 5 + A, B = make_data(n, 'float64') + aa, bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.T, A) + assert_array_almost_equal(Q @ BB @ Z.T, B) + assert_array_almost_equal(Q @ Q.T, np.eye(n)) + assert_array_almost_equal(Z @ Z.T, np.eye(n)) + assert (np.all(np.diag(BB) >= 0)) + + def test_numba_qz_float64_large(self): + n = 100 + A, B = make_data(n, 'float64') + aa, bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.T, A) + assert_array_almost_equal(Q @ BB @ Z.T, B) + assert_array_almost_equal(Q @ Q.T, np.eye(n)) + assert_array_almost_equal(Z @ Z.T, np.eye(n)) + assert (np.all(np.diag(BB) >= 0)) + + def test_numba_qz_complex64_small(self): + n = 5 + A, B = make_data(n, 'complex64') + aa, bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.conj().T, A, decimal=5) + assert_array_almost_equal(Q @ BB @ Z.conj().T, B, decimal=5) + assert_array_almost_equal(Q @ Q.conj().T, np.eye(n), decimal=5) + assert_array_almost_equal(Z @ Z.conj().T, np.eye(n), decimal=5) + assert (np.all(np.diag(BB) >= 0)) + + def test_numba_qz_complex64_large(self): + n = 100 + A, B = make_data(n, 'complex64') + aa, bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.conj().T, A, decimal=3) + assert_array_almost_equal(Q @ BB @ Z.conj().T, B, decimal=3) + assert_array_almost_equal(Q @ Q.conj().T, np.eye(n), decimal=3) + assert_array_almost_equal(Z @ Z.conj().T, np.eye(n), decimal=3) + assert (np.all(np.diag(BB) >= 0)) + + def test_numba_qz_complex128_small(self): + n = 5 + A, B = make_data(n, 'complex128') + aa, 
bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.conj().T, A) + assert_array_almost_equal(Q @ BB @ Z.conj().T, B) + assert_array_almost_equal(Q @ Q.conj().T, np.eye(n)) + assert_array_almost_equal(Z @ Z.conj().T, np.eye(n)) + assert (np.all(np.diag(BB) >= 0)) + + def test_numba_qz_complex128_large(self): + n = 100 + A, B = make_data(n, 'complex128') + aa, bb, q, z = linalg.qz(A, B) + AA, BB, Q, Z = self.qz(A, B) + + assert_allclose(aa, AA) + assert_allclose(bb, BB) + assert_allclose(q, Q) + assert_allclose(z, Z) + + assert_array_almost_equal(Q @ AA @ Z.conj().T, A) + assert_array_almost_equal(Q @ BB @ Z.conj().T, B) + assert_array_almost_equal(Q @ Q.conj().T, np.eye(n)) + assert_array_almost_equal(Z @ Z.conj().T, np.eye(n)) + assert (np.all(np.diag(BB) >= 0)) + + +def _select_function(sort): + if sort == 'lhp': + return _lhp + elif sort == 'rhp': + return _rhp + elif sort == 'iuc': + return _iuc + elif sort == 'ouc': + return _ouc + + +class numba_ordqz_test(unittest.TestCase): + + def setUp(self) -> None: + @njit + def numba_ordqz_test(A, B, sort='lhp', output='real'): + return linalg.ordqz(A, B, sort=sort, output=output) + + self.ordqz = numba_ordqz_test + + def test_ordqz_case_1(self): + A = np.array([[-21.10 - 22.50j, 53.5 - 50.5j, -34.5 + 127.5j, + 7.5 + 0.5j], + [-0.46 - 7.78j, -3.5 - 37.5j, -15.5 + 58.5j, + -10.5 - 1.5j], + [4.30 - 5.50j, 39.7 - 17.1j, -68.5 + 12.5j, + -7.5 - 3.5j], + [5.50 + 4.40j, 14.4 + 43.3j, -32.5 - 46.0j, + -19.0 - 32.5j]], dtype='complex128') + + B = np.array([[1.0 - 5.0j, 1.6 + 1.2j, -3 + 0j, 0.0 - 1.0j], + [0.8 - 0.6j, .0 - 5.0j, -4 + 3j, -2.4 - 3.2j], + [1.0 + 0.0j, 2.4 + 1.8j, -4 - 5j, 0.0 - 3.0j], + [0.0 + 1.0j, -1.8 + 2.4j, 0 - 4j, 4.0 - 5.0j]], dtype='complex128') + + for sort in ['lhp', 'rhp', 'iuc', 'ouc']: + numba_ret = self.ordqz(A, B, output='complex', sort=sort) + scipy_ret = linalg.ordqz(A, B, output='complex', sort=sort) + + for A, a in zip(numba_ret, scipy_ret): + assert_array_almost_equal(A, a) + + def test_ordqz_case_2(self): + A = np.array([[3.9, 12.5, -34.5, -0.5], + [4.3, 21.5, -47.5, 7.5], + [4.3, 21.5, -43.5, 3.5], + [4.4, 26.0, -46.0, 6.0]], dtype='float64') + + B = np.array([[1, 2, -3, 1], + [1, 3, -5, 4], + [1, 3, -4, 3], + [1, 3, -4, 4]], dtype='float64') + + for sort in ['lhp', 'rhp', 'iuc', 'ouc']: + numba_ret = self.ordqz(A, B, output='real', sort=sort) + scipy_ret = linalg.ordqz(A, B, output='real', sort=sort) + + for A, a in zip(numba_ret, scipy_ret): + assert_array_almost_equal(A, a) + + def test_ordqz_case_3(self): + A = np.array([[5., 1., 3., 3.], + [4., 4., 2., 7.], + [7., 4., 1., 3.], + [0., 4., 8., 7.]], dtype='float64') + B = np.array([[8., 10., 6., 10.], + [7., 7., 2., 9.], + [9., 1., 6., 6.], + [5., 1., 4., 7.]], dtype='float64') + + for sort in ['lhp', 'rhp', 'iuc', 'ouc']: + numba_ret = self.ordqz(A, B, output='real', sort=sort) + scipy_ret = linalg.ordqz(A, B, output='real', sort=sort) + + for A, a in zip(numba_ret, scipy_ret): + assert_array_almost_equal(A, a) + + def test_ordqz_case_4(self): + A = np.eye(2).astype('float64') + B = np.diag([0, 1]).astype('float64') + + for sort in ['lhp', 'rhp', 'iuc', 'ouc']: + numba_ret = self.ordqz(A, B, output='real', sort=sort) + scipy_ret = linalg.ordqz(A, B, output='real', sort=sort) + + for A, a in zip(numba_ret, scipy_ret): + assert_array_almost_equal(A, a) + + def test_ordqz_case_5(self): + A = np.diag([1, 
0]).astype('float64') + B = np.diag([1, 0]).astype('float64') + + for sort in ['lhp', 'rhp', 'iuc', 'ouc']: + numba_ret = self.ordqz(A, B, output='real', sort=sort) + scipy_ret = linalg.ordqz(A, B, output='real', sort=sort) + + for A, a in zip(numba_ret, scipy_ret): + assert_array_almost_equal(A, a) + + +class numba_solve_lyapunov_tests(unittest.TestCase): + def setUp(self) -> None: + @njit + def numba_solve_continuous_lyapunov(A, B): + return linalg.solve_continuous_lyapunov(A, B) + + @njit + def numba_solve_discrete_lyapunov(A, B, method='auto'): + return linalg.solve_discrete_lyapunov(A, B, method) + + self.solve_continuous_lyapunov = numba_solve_continuous_lyapunov + self.solve_discrete_lyapunov = numba_solve_discrete_lyapunov + + # Numba is much stricter about typing information than vanilla numpy, so type information needed to be + # added to all these test. + self.cases = [ + (np.array([[1, 2], [3, 4]], dtype='float64'), + np.array([[9, 10], [11, 12]], dtype='float64')), + # a, q all complex. + (np.array([[1.0 + 1j, 2.0], [3.0 - 4.0j, 5.0]]), + np.array([[2.0 - 2j, 2.0 + 2j], [-1.0 - 1j, 2.0]])), + # a real; q complex. + (np.array([[1.0, 2.0], [3.0, 5.0]], dtype='complex128'), + np.array([[2.0 - 2j, 2.0 + 2j], [-1.0 - 1j, 2.0]])), + # a complex; q real. + (np.array([[1.0 + 1j, 2.0], [3.0 - 4.0j, 5.0]]), + np.array([[2.0, 2.0], [-1.0, 2.0]], dtype='complex128')), + # An example from Kitagawa, 1977 + (np.array([[3, 9, 5, 1, 4], [1, 2, 3, 8, 4], [4, 6, 6, 6, 3], + [1, 5, 2, 0, 7], [5, 3, 3, 1, 5]], dtype='float64'), + np.array([[2, 4, 1, 0, 1], [4, 1, 0, 2, 0], [1, 0, 3, 0, 3], + [0, 2, 0, 1, 0], [1, 0, 3, 0, 4]], dtype='float64')), + # Companion matrix example. a complex; q real; a.shape[0] = 11 + (np.array([[0.100 + 0.j, 0.091 + 0.j, 0.082 + 0.j, 0.073 + 0.j, 0.064 + 0.j, + 0.055 + 0.j, 0.046 + 0.j, 0.037 + 0.j, 0.028 + 0.j, 0.019 + 0.j, + 0.010 + 0.j], + [1.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 1.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 1.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 1.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 1.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 1.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 1.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 1.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 1.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j], + [0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, + 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 0.000 + 0.j, 1.000 + 0.j, + 0.000 + 0.j]]), + np.eye(11).astype('complex128')), + # https://github.com/scipy/scipy/issues/4176 + (np.matrix([[0, 1], [-1 / 2, -1]], dtype='float64'), + 
(np.matrix([0, 3], dtype='float64').T @ np.matrix([0, 3], dtype='float64').T.T)), + # https://github.com/scipy/scipy/issues/4176 + (np.matrix([[0, 1], [-1 / 2, -1]], dtype='float64'), + (np.array(np.matrix([0, 3], dtype='float64').T @ np.matrix([0, 3], dtype='float64').T.T))) + ] + + def test_solve_continuous_lyapunov(self): + for (A, B) in self.cases: + X = self.solve_continuous_lyapunov(A, B) + x = linalg.solve_continuous_lyapunov(A, B) + assert_array_almost_equal(X, x) + + def test_solve_discrete_lyapunov_auto(self): + for (A, B) in self.cases: + X = self.solve_discrete_lyapunov(A, B) + x = linalg.solve_discrete_lyapunov(A, B) + assert_array_almost_equal(X, x) + + def test_solve_discrete_lyapunov_direct(self): + for (A, B) in self.cases: + X = self.solve_discrete_lyapunov(A, B, method='direct') + x = linalg.solve_discrete_lyapunov(A, B, method='direct') + assert_array_almost_equal(X, x) + + def test_solve_discrete_lyapunov_bilinear(self): + for (A, B) in self.cases: + X = self.solve_discrete_lyapunov(A, B, method='bilinear') + x = linalg.solve_discrete_lyapunov(A, B, method='bilinear') + assert_array_almost_equal(X, x) + + +if __name__ == '__main__': + unittest.main()
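
The tests above exercise the overloads through thin @njit wrappers. For reference, a minimal usage sketch follows; it assumes only what this diff adds, namely that importing numba_scipy.linalg.overloads runs the @overload decorators and thereby registers the scipy.linalg implementations (the tests rely on the same import side effect). The matrices are taken from the first solve_continuous_lyapunov test case.

import numpy as np
from numba import njit
from scipy import linalg

import numba_scipy.linalg.overloads  # noqa: F401 -- importing registers the @overload implementations

@njit
def schur_and_lyapunov(A, Q):
    # Both calls dispatch to the LAPACK-backed implementations in overloads.py.
    # schur() takes `output` positionally here because the overload defines no default.
    T, Z = linalg.schur(A, 'real')
    X = linalg.solve_continuous_lyapunov(A, Q)
    return T, Z, X

A = np.array([[1.0, 2.0], [3.0, 4.0]])
Q = np.array([[9.0, 10.0], [11.0, 12.0]])
T, Z, X = schur_and_lyapunov(A, Q)

# Sanity checks: Z @ T @ Z.T reconstructs A, and X satisfies A X + X A^T = Q (scipy's convention).
np.testing.assert_allclose(Z @ T @ Z.T, A, atol=1e-8)
np.testing.assert_allclose(A @ X + X @ A.T, Q, atol=1e-8)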
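
The intrinsics exist because every scalar argument to a LAPACK routine is passed by reference: val_to_int_ptr (and the float/double/complex variants) stack-allocate a slot holding the value and return a pointer to it, while int_ptr_to_val / dptr_to_val / sptr_to_val load the value back, which is how INFO is read after each call in overloads.py. A minimal round-trip sketch under those assumptions:

from numba import njit
from numba_scipy.linalg.intrinsics import val_to_int_ptr, int_ptr_to_val

@njit
def info_roundtrip():
    INFO = val_to_int_ptr(1)     # int32 slot initialised to 1; this is what would be handed to LAPACK
    return int_ptr_to_val(INFO)  # load the (possibly LAPACK-modified) value back out

assert info_roundtrip() == 1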
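
LAPACK.py uses the standard recipe for reaching scipy.linalg.cython_lapack from nopython mode: get_cython_function_address looks up the exported symbol and ctypes.CFUNCTYPE wraps it with the Fortran argument list; each overload then calls the routine twice, first with LWORK = -1 (a workspace query that only writes the optimal size into WORK[0]) and then again with the allocated workspace. The wrappers can also be built from plain Python, which is a quick way to confirm which routine a given dtype resolves to. A sketch, using numba's type objects as the dictionary keys exactly as the overloads do:

from numba.core import types
from numba_scipy.linalg.LAPACK import _LAPACK

lapack = _LAPACK()
print(lapack.test_blas_kinds(types.float64))     # 'd' -> dgees / dgges / dtgsen / dtrsyl
print(lapack.test_blas_kinds(types.complex128))  # 'z' -> zgees / zgges / ztgsen / ztrsyl

dgees = lapack.numba_rgees(types.float64)  # ctypes callable bound to LAPACK's dgees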