Skip to content

Commit

Permalink
Add scipy class extensions that use mkl for matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Nov 7, 2024
1 parent 2ef0643 commit 8f181da
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 1 deletion.
49 changes: 49 additions & 0 deletions sparse_dot_mkl/sparse_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from scipy.sparse import (
csr_array as _sps_csr_array,
csr_matrix as _sps_csr_matrix,
csc_array as _sps_csc_array,
csc_matrix as _sps_csc_matrix,
bsr_array as _sps_bsr_array,
bsr_matrix as _sps_bsr_matrix
)

from sparse_dot_mkl import dot_product_mkl

class _mkl_matmul_mixin:

dense_matmul = False
cast_matmul = True

def __matmul__(self, other):
return dot_product_mkl(
self,
other,
dense=self.dense_matmul,
cast=self.cast_matmul
)

def __rmatmul__(self, other):
return dot_product_mkl(
other,
self,
dense=self.dense_matmul,
cast=self.cast_matmul
)

class csr_array(_mkl_matmul_mixin, _sps_csr_array):
pass

class csr_matrix(_mkl_matmul_mixin, _sps_csr_matrix):
pass

class csc_array(_mkl_matmul_mixin, _sps_csc_array):
pass

class csc_matrix(_mkl_matmul_mixin, _sps_csc_matrix):
pass

class bsr_array(_mkl_matmul_mixin, _sps_bsr_array):
pass

class bsr_matrix(_mkl_matmul_mixin, _sps_bsr_matrix):
pass
8 changes: 7 additions & 1 deletion sparse_dot_mkl/tests/test_pardiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ def test_pardiso_init(self):
if self.single_precision:
_iparm_init[27] = 1

npt.assert_equal(self.iparm, _iparm_init)
# Default value changed from 0 (2 iterations implicit)
# to 2
try:
npt.assert_equal(self.iparm, _iparm_init)
except AssertionError:
_iparm_init[7] = 2
npt.assert_equal(self.iparm, _iparm_init)

def test_pardiso_analysis(self):

Expand Down
102 changes: 102 additions & 0 deletions sparse_dot_mkl/tests/test_scipy_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import unittest
import numpy.testing as npt
import scipy.sparse as sps
from types import MethodType

from sparse_dot_mkl import dot_product_mkl
from sparse_dot_mkl.sparse_array import (
csr_array,
csc_array,
bsr_array,
csc_matrix,
csr_matrix,
bsr_matrix
)
from sparse_dot_mkl.tests.test_mkl import MATRIX_1, MATRIX_2, make_matrixes

MATMUL = MATRIX_1 @ MATRIX_2
MATMUL = MATMUL.toarray()

def _tripwire(self, other):
raise RuntimeError("Shouldn't be here")

def install_wire(x):
x._matmul_dispatch = MethodType(_tripwire, x)
x._rmatmul_dispatch = MethodType(_tripwire, x)


class TestCSR(unittest.TestCase):

arr = csr_array

def test_matmul(self):

a = self.arr(MATRIX_1)
b = self.arr(MATRIX_2)

install_wire(a)
install_wire(b)

c = a @ b

npt.assert_almost_equal(
c.toarray(),
MATMUL
)

def test_matmul_dense(self):

a = self.arr(MATRIX_1)
b = self.arr(MATRIX_2)

install_wire(a)
install_wire(b)

a.dense_matmul = True
a.dense_matmul = True

c = a @ b

self.assertFalse(
sps.issparse(c)
)

npt.assert_almost_equal(
c,
MATMUL
)

def test_matmul_fail(self):

a = self.arr(MATRIX_1)
b = self.arr(MATRIX_2)

with self.assertRaises(ValueError):
b @ a

m1 = MATRIX_1.copy()
install_wire(m1)

with self.assertRaises(RuntimeError):
m1 @ MATRIX_2


class TestCSRMat(TestCSR):
arr = csr_matrix


class TestCSC(TestCSR):
arr = csc_array


class TestCSCMat(TestCSR):
arr = csc_matrix


class TestBSRMat(TestCSR):
arr = bsr_matrix


class TestBSC(TestCSR):
arr = bsr_array

0 comments on commit 8f181da

Please sign in to comment.