-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add scipy class extensions that use mkl for matmul
- Loading branch information
1 parent
2ef0643
commit 8f181da
Showing
3 changed files
with
158 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from scipy.sparse import ( | ||
csr_array as _sps_csr_array, | ||
csr_matrix as _sps_csr_matrix, | ||
csc_array as _sps_csc_array, | ||
csc_matrix as _sps_csc_matrix, | ||
bsr_array as _sps_bsr_array, | ||
bsr_matrix as _sps_bsr_matrix | ||
) | ||
|
||
from sparse_dot_mkl import dot_product_mkl | ||
|
||
class _mkl_matmul_mixin: | ||
|
||
dense_matmul = False | ||
cast_matmul = True | ||
|
||
def __matmul__(self, other): | ||
return dot_product_mkl( | ||
self, | ||
other, | ||
dense=self.dense_matmul, | ||
cast=self.cast_matmul | ||
) | ||
|
||
def __rmatmul__(self, other): | ||
return dot_product_mkl( | ||
other, | ||
self, | ||
dense=self.dense_matmul, | ||
cast=self.cast_matmul | ||
) | ||
|
||
class csr_array(_mkl_matmul_mixin, _sps_csr_array): | ||
pass | ||
|
||
class csr_matrix(_mkl_matmul_mixin, _sps_csr_matrix): | ||
pass | ||
|
||
class csc_array(_mkl_matmul_mixin, _sps_csc_array): | ||
pass | ||
|
||
class csc_matrix(_mkl_matmul_mixin, _sps_csc_matrix): | ||
pass | ||
|
||
class bsr_array(_mkl_matmul_mixin, _sps_bsr_array): | ||
pass | ||
|
||
class bsr_matrix(_mkl_matmul_mixin, _sps_bsr_matrix): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import unittest | ||
import numpy.testing as npt | ||
import scipy.sparse as sps | ||
from types import MethodType | ||
|
||
from sparse_dot_mkl import dot_product_mkl | ||
from sparse_dot_mkl.sparse_array import ( | ||
csr_array, | ||
csc_array, | ||
bsr_array, | ||
csc_matrix, | ||
csr_matrix, | ||
bsr_matrix | ||
) | ||
from sparse_dot_mkl.tests.test_mkl import MATRIX_1, MATRIX_2, make_matrixes | ||
|
||
MATMUL = MATRIX_1 @ MATRIX_2 | ||
MATMUL = MATMUL.toarray() | ||
|
||
def _tripwire(self, other): | ||
raise RuntimeError("Shouldn't be here") | ||
|
||
def install_wire(x): | ||
x._matmul_dispatch = MethodType(_tripwire, x) | ||
x._rmatmul_dispatch = MethodType(_tripwire, x) | ||
|
||
|
||
class TestCSR(unittest.TestCase): | ||
|
||
arr = csr_array | ||
|
||
def test_matmul(self): | ||
|
||
a = self.arr(MATRIX_1) | ||
b = self.arr(MATRIX_2) | ||
|
||
install_wire(a) | ||
install_wire(b) | ||
|
||
c = a @ b | ||
|
||
npt.assert_almost_equal( | ||
c.toarray(), | ||
MATMUL | ||
) | ||
|
||
def test_matmul_dense(self): | ||
|
||
a = self.arr(MATRIX_1) | ||
b = self.arr(MATRIX_2) | ||
|
||
install_wire(a) | ||
install_wire(b) | ||
|
||
a.dense_matmul = True | ||
a.dense_matmul = True | ||
|
||
c = a @ b | ||
|
||
self.assertFalse( | ||
sps.issparse(c) | ||
) | ||
|
||
npt.assert_almost_equal( | ||
c, | ||
MATMUL | ||
) | ||
|
||
def test_matmul_fail(self): | ||
|
||
a = self.arr(MATRIX_1) | ||
b = self.arr(MATRIX_2) | ||
|
||
with self.assertRaises(ValueError): | ||
b @ a | ||
|
||
m1 = MATRIX_1.copy() | ||
install_wire(m1) | ||
|
||
with self.assertRaises(RuntimeError): | ||
m1 @ MATRIX_2 | ||
|
||
|
||
class TestCSRMat(TestCSR): | ||
arr = csr_matrix | ||
|
||
|
||
class TestCSC(TestCSR): | ||
arr = csc_array | ||
|
||
|
||
class TestCSCMat(TestCSR): | ||
arr = csc_matrix | ||
|
||
|
||
class TestBSRMat(TestCSR): | ||
arr = bsr_matrix | ||
|
||
|
||
class TestBSC(TestCSR): | ||
arr = bsr_array | ||
|