From 8d5ddf14e08e9006e5d97ac8ec070f1e340e3c31 Mon Sep 17 00:00:00 2001 From: Benjamin Desef Date: Wed, 6 Nov 2024 14:10:12 +0100 Subject: [PATCH] Some more Gram matrix emptyness safety --- sparse_dot_mkl/_gram_matrix.py | 41 +++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/sparse_dot_mkl/_gram_matrix.py b/sparse_dot_mkl/_gram_matrix.py index cbf6813..2bbea66 100644 --- a/sparse_dot_mkl/_gram_matrix.py +++ b/sparse_dot_mkl/_gram_matrix.py @@ -283,6 +283,22 @@ def _gram_matrix( :rtype: scipy.sparse.csr_matrix, np.ndarray """ + if _sps.issparse(matrix) and not (is_csr(matrix) or is_csc(matrix)): + raise ValueError( + "gram_matrix requires sparse matrix to be CSR or CSC format" + ) + elif is_csc(matrix) and not cast: + raise ValueError( + "gram_matrix cannot use a CSC matrix unless cast=True" + ) + elif out is not None and not dense: + raise ValueError( + "out argument cannot be used with sparse (dot) sparse " + "matrix multiplication" + ) + elif out is not None and not isinstance(out, np.ndarray): + raise ValueError("out argument must be dense") + # Check for edge condition inputs which result in empty outputs if _empty_output_check(matrix, matrix): debug_print( @@ -294,8 +310,14 @@ def _gram_matrix( if transpose else (matrix.shape[0], matrix.shape[0]) ) - output_func = _sps.csr_matrix if _sps.isspmatrix(matrix) else np.zeros - return output_func(output_shape, dtype=matrix.dtype) + if out is None: + output_func = np.zeros if dense else _sps.csr_matrix + return output_func(output_shape, dtype=matrix.dtype) + elif out_scalar is not None and not out_scalar: + out.fill(0) + elif out_scalar is not None: + out *= out_scalar + return out if np.iscomplexobj(matrix): raise ValueError( @@ -304,15 +326,7 @@ def _gram_matrix( matrix = _type_check(matrix, cast=cast) - if _sps.issparse(matrix) and not (is_csr(matrix) or is_csc(matrix)): - raise ValueError( - "gram_matrix requires sparse matrix to be CSR or CSC format" - ) - elif is_csc(matrix) and not cast: - raise ValueError( - "gram_matrix cannot use a CSC matrix unless cast=True" - ) - elif not _sps.issparse(matrix): + if not _sps.issparse(matrix): return _gram_matrix_dense_to_dense( matrix, aat=transpose, @@ -326,11 +340,6 @@ def _gram_matrix( out=out, out_scalar=out_scalar ) - elif out is not None: - raise ValueError( - "out argument cannot be used with sparse (dot) sparse " - "matrix multiplication" - ) else: return _gram_matrix_sparse( matrix,