Skip to content

Commit

Permalink
Some more Gram matrix emptyness safety
Browse files Browse the repository at this point in the history
  • Loading branch information
projekter committed Nov 6, 2024
1 parent 59bcb1a commit 8d5ddf1
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions sparse_dot_mkl/_gram_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,22 @@ def _gram_matrix(
:rtype: scipy.sparse.csr_matrix, np.ndarray
"""

if _sps.issparse(matrix) and not (is_csr(matrix) or is_csc(matrix)):
raise ValueError(
"gram_matrix requires sparse matrix to be CSR or CSC format"
)
elif is_csc(matrix) and not cast:
raise ValueError(
"gram_matrix cannot use a CSC matrix unless cast=True"
)
elif out is not None and not dense:
raise ValueError(
"out argument cannot be used with sparse (dot) sparse "
"matrix multiplication"
)
elif out is not None and not isinstance(out, np.ndarray):
raise ValueError("out argument must be dense")

# Check for edge condition inputs which result in empty outputs
if _empty_output_check(matrix, matrix):
debug_print(
Expand All @@ -294,8 +310,14 @@ def _gram_matrix(
if transpose
else (matrix.shape[0], matrix.shape[0])
)
output_func = _sps.csr_matrix if _sps.isspmatrix(matrix) else np.zeros
return output_func(output_shape, dtype=matrix.dtype)
if out is None:
output_func = np.zeros if dense else _sps.csr_matrix
return output_func(output_shape, dtype=matrix.dtype)
elif out_scalar is not None and not out_scalar:
out.fill(0)
elif out_scalar is not None:
out *= out_scalar
return out

if np.iscomplexobj(matrix):
raise ValueError(
Expand All @@ -304,15 +326,7 @@ def _gram_matrix(

matrix = _type_check(matrix, cast=cast)

if _sps.issparse(matrix) and not (is_csr(matrix) or is_csc(matrix)):
raise ValueError(
"gram_matrix requires sparse matrix to be CSR or CSC format"
)
elif is_csc(matrix) and not cast:
raise ValueError(
"gram_matrix cannot use a CSC matrix unless cast=True"
)
elif not _sps.issparse(matrix):
if not _sps.issparse(matrix):
return _gram_matrix_dense_to_dense(
matrix,
aat=transpose,
Expand All @@ -326,11 +340,6 @@ def _gram_matrix(
out=out,
out_scalar=out_scalar
)
elif out is not None:
raise ValueError(
"out argument cannot be used with sparse (dot) sparse "
"matrix multiplication"
)
else:
return _gram_matrix_sparse(
matrix,
Expand Down

0 comments on commit 8d5ddf1

Please sign in to comment.