Skip to content

Commit

Permalink
Add wrapper for PARDISO
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Oct 3, 2024
1 parent 5d60893 commit d12521d
Show file tree
Hide file tree
Showing 5 changed files with 411 additions and 12 deletions.
34 changes: 34 additions & 0 deletions sparse_dot_mkl/_mkl_interface/_cfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ class MKL:
_mkl_get_version = _libmkl.MKL_Get_Version
_mkl_get_version_string = _libmkl.MKL_Get_Version_String

# PARDISO
_pardisoinit = _libmkl.pardisoinit
_pardiso = _libmkl.pardiso

@classmethod
def _set_int_type(cls, c_type, _np_type):
cls.MKL_INT = c_type
Expand All @@ -158,6 +162,36 @@ def _set_int_type(cls, c_type, _np_type):
cls._set_int_type_qr_solver()
cls._set_int_type_syrk()
cls._set_int_type_misc()
cls._set_int_type_pardiso()

@classmethod
def _set_int_type_pardiso(cls):
cls._pardiso.argtypes = [
ndpointer(shape=(64,), dtype=_np.int64), #pt
_ctypes.POINTER(MKL.MKL_INT), #maxfct
_ctypes.POINTER(MKL.MKL_INT), #mnum
_ctypes.POINTER(MKL.MKL_INT), #mtype
_ctypes.POINTER(MKL.MKL_INT), #phase
_ctypes.POINTER(MKL.MKL_INT), #n
ndpointer(ndim=1), #a
ndpointer(dtype=MKL.MKL_INT, ndim=1), #ia
ndpointer(dtype=MKL.MKL_INT, ndim=1), #ja
ndpointer(dtype=MKL.MKL_INT, ndim=1), #perm
_ctypes.POINTER(MKL.MKL_INT), #nrhs
ndpointer(shape=(64,), dtype=MKL.MKL_INT), #iparm
_ctypes.POINTER(MKL.MKL_INT), #msglvl
ndpointer(), #b
ndpointer(flags="C_CONTIGUOUS"), #x
_ctypes.POINTER(MKL.MKL_INT) #error
]
cls._pardiso.restype = None

cls._pardisoinit.argtypes = [
ndpointer(shape=(64,), dtype=_np.int64),
_ctypes.POINTER(MKL.MKL_INT),
ndpointer(shape=(64,), dtype=MKL.MKL_INT_NUMPY)
]
cls._pardisoinit.restype = None

@classmethod
def _set_int_type_create(cls):
Expand Down
15 changes: 9 additions & 6 deletions sparse_dot_mkl/_mkl_interface/_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from ._cfunctions import MKL, mkl_library_name
from ._cfunctions import (
MKL,
mkl_library_name,
mkl_get_version_string,
mkl_get_max_threads
)

from ._constants import (
LAYOUT_CODE_C,
LAYOUT_CODE_F,
Expand Down Expand Up @@ -107,11 +113,8 @@ def print_mkl_debug():
if not MKL.MKL_DEBUG:
return

if get_version_string() is None:
print("mkl-service must be installed to get full debug messaging")
else:
print(get_version_string())
print(f"MKL Number of Threads: {get_max_threads()}")
print(mkl_get_version_string())
print(f"MKL Number of Threads: {mkl_get_max_threads()}")

print(f"MKL linked: {mkl_library_name()}")
print(f"MKL interface {MKL.MKL_INT_NUMPY} | {MKL.MKL_INT}")
Expand Down
13 changes: 7 additions & 6 deletions sparse_dot_mkl/_sparse_qr_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
matrix_descr,
_convert_to_csr,
_check_return_value,
LAYOUT_CODE_C
LAYOUT_CODE_C,
is_csc,
is_csr
)

import numpy as np
import ctypes as _ctypes
import scipy.sparse as _spsparse

# Keyed by bool for double-precision
SOLVE_FUNCS = {
Expand Down Expand Up @@ -52,7 +53,7 @@ def _sparse_qr(
output_shape = matrix_a.shape[1], matrix_b.shape[1]

# Convert a CSC matrix to CSR
if _spsparse.isspmatrix_csc(matrix_a):
if is_csc(matrix_a):
mkl_a = _convert_to_csr(mkl_a)
_mkl_handles.append(mkl_a)

Expand Down Expand Up @@ -127,14 +128,14 @@ def sparse_qr_solver(
:rtype: numpy.ndarray
"""

if _spsparse.isspmatrix_csc(matrix_a) and not cast:
if is_csc(matrix_a) and not cast:
raise ValueError(
"sparse_qr_solver only accepts CSR matrices if cast=False"
)

elif (
not _spsparse.isspmatrix_csr(matrix_a) and
not _spsparse.isspmatrix_csc(matrix_a)
not is_csc(matrix_a) and
not is_csr(matrix_a)
):
raise ValueError(
"sparse_qr_solver requires matrix A to be CSR or CSC sparse matrix"
Expand Down
224 changes: 224 additions & 0 deletions sparse_dot_mkl/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import numpy as np
import ctypes as _ctypes
import scipy.sparse as sps

import warnings

from sparse_dot_mkl._mkl_interface._cfunctions import (
MKL
)
from sparse_dot_mkl._mkl_interface._common import (
is_csr
)

PARDISO_ERRORS = {
0: None,
-1: "input inconsistent",
-2: "not enough memory",
-3: "reordering problem",
-4: "Zero pivot, numerical factorization or iterative refinement problem",
-5: "unclassified (internal) error",
-6: "reordering failed (matrix types 11 and 13 only)",
-7: "diagonal matrix is singular",
-8: "32-bit integer overflow problem",
-9: "not enough memory for OOC",
-10: "error opening OOC files",
-11: "read/write error with OOC files",
-12: "(pardiso_64 only) pardiso_64 called from 32-bit library",
-13: "interrupted by the (user-defined) mkl_progress function",
-15: "internal error which can appear for iparm[23]=10 and iparm[12]=1"
}

def pardiso(
A,
B,
pt,
mtype,
iparm,
phase,
maxfct=1,
mnum=1,
perm=None,
msglvl=0,
X=None,
quiet=False
):
"""
Run pardiso solver for AX = B
:param A: Matrix A in CSR format
:type A: sp.sparse.csr_array, sp.sparse.csr_matrix
:param B: Matrix B in dense format
:type B: np.ndarray
:param pt: Pointer array, shape=(64,) dtype=int64
:type pt: np.ndarray
:param mtype: Matrix type:
1 Real and structurally symmetric
2 Real and symmetric positive definite
-2 Real and symmetric indefinite
3 Complex and structurally symmetric
4 Complex and Hermitian positive definite
-4 Complex and Hermitian indefinite
6 Complex and symmetric matrix
11 Real and nonsymmetric matrix
13 Complex and nonsymmetric matrix
:type mtype: int
:param iparm: Solver parameters array, shape=(64,)
:type iparm: np.ndarray
:param phase: Solver phase
:type phase: int
:param maxfct: Pardiso maxfct, defaults to 1
:type maxfct: int, optional
:param mnum: Pardiso mnum, defaults to 1
:type mnum: int, optional
:param perm: Permutation vector array, new allocation if None,
defaults to None
:type perm: np.ndarray, optional
:param msglvl: Pardiso message level, defaults to 0
:type msglvl: int, optional
:param X: Solved array X, new allocation if None,
defaults to None
:type X: np.ndarray, optional
:param quiet: Don't issue runtime warnings if pardiso
returnvalue != 0, defaults to False
:type quiet: bool, optional
:return:
Solved array X,
Pointer array pt,
Permutation array perm,
Return value error
:rtype: np.ndarray, np.ndarray, np.ndarray, int
"""

if not is_csr(A):
raise ValueError(
f'A must be a CSR matrix; {type(A)} passed'
)

if sps.issparse(B):
raise ValueError(
f'B must be a dense array; {type(B)} passed'
)

if A.shape[0] != B.shape[0]:
raise ValueError(
f"Bad matrix shapes for AX=B solver: "
f"A {A.shape} & B {B.shape}"
)
else:
N = A.shape[0]

if perm is None:
perm = np.zeros(N, dtype=MKL.MKL_INT_NUMPY)

if B.ndim == 1:
nrhs = 1
elif B.ndim > 2:
raise ValueError('B must be 1- or 2-d')
else:
nrhs = B.shape[1]

if X is None:
X = np.zeros_like(B)

error = MKL.MKL_INT(0)

MKL._pardiso(
pt,
_ctypes.byref(MKL.MKL_INT(maxfct)),
_ctypes.byref(MKL.MKL_INT(mnum)),
_ctypes.byref(MKL.MKL_INT(mtype)),
_ctypes.byref(MKL.MKL_INT(phase)),
_ctypes.byref(MKL.MKL_INT(N)),
A.data,
A.indptr.astype(MKL.MKL_INT_NUMPY),
A.indices.astype(MKL.MKL_INT_NUMPY),
perm,
_ctypes.byref(MKL.MKL_INT(nrhs)),
iparm,
_ctypes.byref(MKL.MKL_INT(msglvl)),
B,
X,
_ctypes.byref(error)
)

error = error.value

if error != 0 and not quiet:
warnings.warn(
f"MKL pardiso error {error}: " +
PARDISO_ERRORS[error],
RuntimeWarning
)

return X, pt, perm, error


def pardisoinit(
mtype,
pt=None,
iparm=None,
single_precision=None,
zero_indexing=True
):
"""
Run pardisoinit to initialize pt and iparm for
a given matrix type
:param mtype: Matrix type:
1 Real and structurally symmetric
2 Real and symmetric positive definite
-2 Real and symmetric indefinite
3 Complex and structurally symmetric
4 Complex and Hermitian positive definite
-4 Complex and Hermitian indefinite
6 Complex and symmetric matrix
11 Real and nonsymmetric matrix
13 Complex and nonsymmetric matrix
:type mtype: int
:param pt: Pointer array (int64), new allocation if None,
defaults to None
:type pt: np.ndarray, optional
:param iparm: Solver parameters array, new allocation if None,
defaults to None
:type iparm: np.ndarray, optional
:param single_precision: Set iparm flag for single precision if True,
set flag for double precision if False, do not change flag value in
iparm if None, defaults to None
:type single_precision: bool, optional
:param zero_indexing: Set iparm flag for zero indexing (C & python),
if True, set flag for one indexing (F) if False, do not change flag
value in iparm if None, defaults to True
:type zero_indexing: bool, optional
:return: pt (pointer) and iparm (parameter) arrays for pardiso
:rtype: np.ndarray, np.ndarray
"""

if pt is None:
pt = np.empty(64, np.int64)

if iparm is None:
iparm = np.zeros(64, dtype=MKL.MKL_INT_NUMPY)

MKL._pardisoinit(
pt,
_ctypes.byref(MKL.MKL_INT(mtype)),
iparm
)

# Set zero indexing flag in iparm[34]
if zero_indexing is None:
pass
elif zero_indexing:
iparm[34] = 1
else:
iparm[34] = 0

if single_precision is None:
pass
elif single_precision:
iparm[27] = 1
else:
iparm[27] = 0

return pt, iparm
Loading

0 comments on commit d12521d

Please sign in to comment.