From 8f181dacdc34224bb495156355dc7aa62f2beadd Mon Sep 17 00:00:00 2001 From: asistradition Date: Thu, 7 Nov 2024 13:19:03 -0500 Subject: [PATCH] Add scipy class extensions that use mkl for matmul --- sparse_dot_mkl/sparse_array.py | 49 ++++++++++ sparse_dot_mkl/tests/test_pardiso.py | 8 +- sparse_dot_mkl/tests/test_scipy_classes.py | 102 +++++++++++++++++++++ 3 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 sparse_dot_mkl/sparse_array.py create mode 100644 sparse_dot_mkl/tests/test_scipy_classes.py diff --git a/sparse_dot_mkl/sparse_array.py b/sparse_dot_mkl/sparse_array.py new file mode 100644 index 0000000..10a01b7 --- /dev/null +++ b/sparse_dot_mkl/sparse_array.py @@ -0,0 +1,49 @@ +from scipy.sparse import ( + csr_array as _sps_csr_array, + csr_matrix as _sps_csr_matrix, + csc_array as _sps_csc_array, + csc_matrix as _sps_csc_matrix, + bsr_array as _sps_bsr_array, + bsr_matrix as _sps_bsr_matrix +) + +from sparse_dot_mkl import dot_product_mkl + +class _mkl_matmul_mixin: + + dense_matmul = False + cast_matmul = True + + def __matmul__(self, other): + return dot_product_mkl( + self, + other, + dense=self.dense_matmul, + cast=self.cast_matmul + ) + + def __rmatmul__(self, other): + return dot_product_mkl( + other, + self, + dense=self.dense_matmul, + cast=self.cast_matmul + ) + +class csr_array(_mkl_matmul_mixin, _sps_csr_array): + pass + +class csr_matrix(_mkl_matmul_mixin, _sps_csr_matrix): + pass + +class csc_array(_mkl_matmul_mixin, _sps_csc_array): + pass + +class csc_matrix(_mkl_matmul_mixin, _sps_csc_matrix): + pass + +class bsr_array(_mkl_matmul_mixin, _sps_bsr_array): + pass + +class bsr_matrix(_mkl_matmul_mixin, _sps_bsr_matrix): + pass diff --git a/sparse_dot_mkl/tests/test_pardiso.py b/sparse_dot_mkl/tests/test_pardiso.py index d195e34..3ea8593 100644 --- a/sparse_dot_mkl/tests/test_pardiso.py +++ b/sparse_dot_mkl/tests/test_pardiso.py @@ -37,7 +37,13 @@ def test_pardiso_init(self): if self.single_precision: _iparm_init[27] = 1 - npt.assert_equal(self.iparm, _iparm_init) + # Default value changed from 0 (2 iterations implicit) + # to 2 + try: + npt.assert_equal(self.iparm, _iparm_init) + except AssertionError: + _iparm_init[7] = 2 + npt.assert_equal(self.iparm, _iparm_init) def test_pardiso_analysis(self): diff --git a/sparse_dot_mkl/tests/test_scipy_classes.py b/sparse_dot_mkl/tests/test_scipy_classes.py new file mode 100644 index 0000000..b50c5cb --- /dev/null +++ b/sparse_dot_mkl/tests/test_scipy_classes.py @@ -0,0 +1,102 @@ +import unittest +import numpy.testing as npt +import scipy.sparse as sps +from types import MethodType + +from sparse_dot_mkl import dot_product_mkl +from sparse_dot_mkl.sparse_array import ( + csr_array, + csc_array, + bsr_array, + csc_matrix, + csr_matrix, + bsr_matrix +) +from sparse_dot_mkl.tests.test_mkl import MATRIX_1, MATRIX_2, make_matrixes + +MATMUL = MATRIX_1 @ MATRIX_2 +MATMUL = MATMUL.toarray() + +def _tripwire(self, other): + raise RuntimeError("Shouldn't be here") + +def install_wire(x): + x._matmul_dispatch = MethodType(_tripwire, x) + x._rmatmul_dispatch = MethodType(_tripwire, x) + + +class TestCSR(unittest.TestCase): + + arr = csr_array + + def test_matmul(self): + + a = self.arr(MATRIX_1) + b = self.arr(MATRIX_2) + + install_wire(a) + install_wire(b) + + c = a @ b + + npt.assert_almost_equal( + c.toarray(), + MATMUL + ) + + def test_matmul_dense(self): + + a = self.arr(MATRIX_1) + b = self.arr(MATRIX_2) + + install_wire(a) + install_wire(b) + + a.dense_matmul = True + a.dense_matmul = True + + c = a @ b + + self.assertFalse( + sps.issparse(c) + ) + + npt.assert_almost_equal( + c, + MATMUL + ) + + def test_matmul_fail(self): + + a = self.arr(MATRIX_1) + b = self.arr(MATRIX_2) + + with self.assertRaises(ValueError): + b @ a + + m1 = MATRIX_1.copy() + install_wire(m1) + + with self.assertRaises(RuntimeError): + m1 @ MATRIX_2 + + +class TestCSRMat(TestCSR): + arr = csr_matrix + + +class TestCSC(TestCSR): + arr = csc_array + + +class TestCSCMat(TestCSR): + arr = csc_matrix + + +class TestBSRMat(TestCSR): + arr = bsr_matrix + + +class TestBSC(TestCSR): + arr = bsr_array +