Skip to content

Commit

Permalink
v0.9.6 ISS
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Oct 16, 2024
1 parent e6729e0 commit eff90e6
Show file tree
Hide file tree
Showing 12 changed files with 1,540 additions and 230 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
### Version 0.9.6

* Add wrapper for MKL iterative CG and FGMRES solvers in `sparse_dot.solvers`

### Version 0.9.5

* Add wrapper for MKL pardiso solver in `sparse_dot.solvers`
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from setuptools import setup, find_packages

DISTNAME = 'sparse_dot_mkl'
VERSION = '0.9.5'
VERSION = '0.9.6'
DESCRIPTION = "Intel MKL wrapper for sparse matrix multiplication"
MAINTAINER = 'Chris Jackson'
MAINTAINER_EMAIL = '[email protected]'
Expand Down
9 changes: 8 additions & 1 deletion sparse_dot_mkl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.9.5'
__version__ = '0.9.6'


from sparse_dot_mkl.sparse_dot import (
Expand All @@ -19,4 +19,11 @@
mkl_interface_integer_dtype
)

from .solvers import (
pardiso,
pardisoinit,
fgmres,
cg
)

get_version_string = mkl_get_version_string
98 changes: 98 additions & 0 deletions sparse_dot_mkl/_mkl_interface/_cfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,29 @@ class MKL:
_mkl_set_num_threads_local = _libmkl.MKL_Set_Num_Threads_Local
_mkl_get_version = _libmkl.MKL_Get_Version
_mkl_get_version_string = _libmkl.MKL_Get_Version_String
_mkl_free_buffers = _libmkl.mkl_free_buffers

# PARDISO
_pardisoinit = _libmkl.pardisoinit
_pardiso = _libmkl.pardiso

# CG Solver
_dcg_init = _libmkl.dcg_init
_dcg_check = _libmkl.dcg_check
_dcg = _libmkl.dcg
_dcg_get = _libmkl.dcg_get

_dcgmrhs_init = _libmkl.dcgmrhs_init
_dcgmrhs_check = _libmkl.dcgmrhs_check
_dcgmrhs = _libmkl.dcgmrhs
_dcgmrhs_get = _libmkl.dcgmrhs_get

# FGMRES Solver
_dfgmres_init = _libmkl.dfgmres_init
_dfgmres_check = _libmkl.dfgmres_check
_dfgmres = _libmkl.dfgmres
_dfgmres_get = _libmkl.dfgmres_get

@classmethod
def _set_int_type(cls, c_type, _np_type):
cls.MKL_INT = c_type
Expand All @@ -163,6 +181,54 @@ def _set_int_type(cls, c_type, _np_type):
cls._set_int_type_syrk()
cls._set_int_type_misc()
cls._set_int_type_pardiso()
cls._set_int_type_iss()

@classmethod
def _set_int_type_iss(cls):
cls._dcg_init.argtypes = cls._create_iss_argtypes()
cls._dcg_init.restype = None

cls._dcg_check.argtypes = cls._create_iss_argtypes()
cls._dcg_check.restype = None

cls._dcg.argtypes = cls._create_iss_argtypes()
cls._dcg.restype = None

cls._dcg_get.argtypes = cls._create_iss_argtypes() + [
_ctypes.POINTER(MKL.MKL_INT)
]
cls._dcg_get.restype = None

cls._dcgmrhs_init.argtypes = cls._create_iss_mrhs_argtypes(
add_method=True
)
cls._dcgmrhs_init.restype = None

cls._dcgmrhs_check.argtypes = cls._create_iss_mrhs_argtypes()
cls._dcgmrhs_check.restype = None

cls._dfgmres.argtypes = cls._create_iss_mrhs_argtypes()
cls._dfgmres.restype = None

cls._dfgmres_get.argtypes = cls._create_iss_mrhs_argtypes() + [
_ctypes.POINTER(MKL.MKL_INT)
]
cls._dfgmres_get.restype = None

cls._dfgmres_init.argtypes = cls._create_iss_argtypes()
cls._dfgmres_init.restype = None

cls._dfgmres_check.argtypes = cls._create_iss_argtypes()
cls._dfgmres_check.restype = None

cls._dfgmres.argtypes = cls._create_iss_argtypes()
cls._dfgmres.restype = None

cls._dfgmres_get.argtypes = cls._create_iss_argtypes() + [
_ctypes.POINTER(MKL.MKL_INT)
]
cls._dfgmres_get.restype = None


@classmethod
def _set_int_type_pardiso(cls):
Expand Down Expand Up @@ -612,6 +678,32 @@ def _qr_solve(prec_type):
MKL.MKL_INT,
]

@staticmethod
def _create_iss_argtypes():
return [
_ctypes.POINTER(MKL.MKL_INT),
ndpointer(dtype=_ctypes.c_double, ndim=1, flags='C_CONTIGUOUS'),
ndpointer(dtype=_ctypes.c_double, ndim=1, flags='C_CONTIGUOUS'),
_ctypes.POINTER(MKL.MKL_INT),
ndpointer(dtype=MKL.MKL_INT, shape=(128,), flags='C_CONTIGUOUS'),
ndpointer(dtype=_ctypes.c_double, shape=(128,), flags='C_CONTIGUOUS'),
ndpointer(dtype=_ctypes.c_double, flags='C_CONTIGUOUS')
]

@staticmethod
def _create_iss_mrhs_argtypes(add_method=False):
_arg = MKL._create_iss_argtypes()[0:2] + [
_ctypes.POINTER(MKL.MKL_INT)
]
if add_method:
_arg = _arg + MKL._create_iss_argtypes()[2:3] + [
_ctypes.POINTER(MKL.MKL_INT)
] + MKL._create_iss_argtypes()[3:]
else:
_arg = _arg + MKL._create_iss_argtypes()[2:]

return _arg


# Set argtypes and return types for service functions
# not interface dependent
Expand All @@ -630,6 +722,8 @@ def _qr_solve(prec_type):
_ctypes.c_int
]
MKL._mkl_get_version_string.restypes = None
MKL._mkl_free_buffers.argtypes = None
MKL._mkl_free_buffers.restype = None


def mkl_set_interface_layer(layer_code):
Expand Down Expand Up @@ -673,6 +767,10 @@ def mkl_get_version_string():
return c_str.value.decode()


def mkl_free_buffers():
MKL._mkl_free_buffers()


_mkl_interface_env = os.getenv('MKL_INTERFACE_LAYER')


Expand Down
2 changes: 1 addition & 1 deletion sparse_dot_mkl/_mkl_interface/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _create_mkl_sparse(matrix):
"""
Create MKL internal representation
:param matrix: Sparse data in CSR or CSC format
:param matrix: Sparse data in CSR or CSC or BSR format
:type matrix: scipy.sparse.spmatrix
:return ref, double_precision: Handle for the MKL internal representation
Expand Down
Loading

0 comments on commit eff90e6

Please sign in to comment.