Skip to content

Commit

Permalink
Fix #28
Browse files: browse the repository at this point in the history
  • Loading branch information
projekter committed Nov 6, 2024
1 parent 5ea1658 commit 59bcb1a
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 39 deletions.
11 changes: 8 additions & 3 deletions sparse_dot_mkl/_dense_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _dense_matmul(matrix_a, matrix_b, scalar=1., out=None, out_scalar=None):
# The complex versions of these functions take void pointers instead of passed structs
# So create a C struct if necessary to be passed by reference
scalar = _mkl_scalar(scalar, complex_type, double_precision)
out_scalar = _mkl_scalar(out_scalar, complex_type, double_precision)
out_scalar = _mkl_scalar(0 if out is None else out_scalar, complex_type, double_precision)

func(layout_a,
111,
Expand All @@ -75,8 +75,13 @@ def _dense_dot_dense(matrix_a, matrix_b, cast=False, scalar=1., out=None, out_sc
# Check for edge condition inputs which result in empty outputs
if _empty_output_check(matrix_a, matrix_b):
debug_print("Skipping multiplication because A (dot) B must yield an empty matrix")
final_dtype = np.float64 if matrix_a.dtype != matrix_b.dtype or matrix_a.dtype != np.float32 else np.float32
return _out_matrix((matrix_a.shape[0], matrix_b.shape[1]), final_dtype, out_arr=out)
output_arr = _out_matrix((matrix_a.shape[0], matrix_b.shape[1]),
_type_check(matrix_a, matrix_b, cast=cast, convert=False), out_arr=out)
if out is None or (out_scalar is not None and not out_scalar):
output_arr.fill(0)
elif out_scalar is not None:
output_arr *= out_scalar
return output_arr

matrix_a, matrix_b = _type_check(matrix_a, matrix_b, cast=cast)

Expand Down
8 changes: 6 additions & 2 deletions sparse_dot_mkl/_gram_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,14 @@ def _gram_matrix_sparse_to_dense(

if _empty_output_check(matrix_a, matrix_a):
_destroy_mkl_handle(sp_ref_a)
if out is None or (out_scalar is not None and not out_scalar):
output_arr.fill(0)
elif out_scalar is not None:
output_arr *= out_scalar
return output_arr

scalar = _mkl_scalar(scalar, complex_type, double_prec)
out_scalar = _mkl_scalar(out_scalar, complex_type, double_prec)
out_scalar = _mkl_scalar(0 if out is None else out_scalar, complex_type, double_prec)

ret_val = func(
_mkl_sp_transpose_ops[(not aat, complex_type)],
Expand Down Expand Up @@ -230,7 +234,7 @@ def _gram_matrix_dense_to_dense(
# passed structs, so create a C struct if necessary to be passed by
# reference
scalar = _mkl_scalar(scalar, complex_type, double_precision)
out_scalar = _mkl_scalar(out_scalar, complex_type, double_precision)
out_scalar = _mkl_scalar(0 if out is None else out_scalar, complex_type, double_precision)

func(
layout_a,
Expand Down
26 changes: 14 additions & 12 deletions sparse_dot_mkl/_mkl_interface/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,12 +770,14 @@ def _is_valid_dtype(matrix, complex_dtype=False, all_dtype=False):
return matrix.dtype in NUMPY_FLOAT_DTYPES


def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True, convert=True):
"""
Make sure that both matrices are single precision floats or both are
double precision floats.
If not, convert to double precision floats if cast is True,
or raise an error if cast is False
If convert is set to False, the resulting data type is returned without
any conversion happening.
"""

_n_complex = _np.iscomplexobj(matrix_a) + _np.iscomplexobj(matrix_b)
Expand All @@ -785,17 +787,17 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):

# If there's no matrix B and matrix A is valid dtype, return it
if matrix_b is None and _is_valid_dtype(matrix_a, all_dtype=True):
return matrix_a
return matrix_a if convert else matrix_a.dtype

# If matrix A is complex but not csingle or cdouble, and cast is True,
# convert it to a cdouble
elif matrix_b is None and cast and _n_complex == 1:
return _cast_to(matrix_a, _np.cdouble)
return _cast_to(matrix_a, _np.cdouble) if convert else _np.cdouble

# If matrix A is real but not float32 or float64, and cast is True,
# convert it to a float64
elif matrix_b is None and cast:
return _cast_to(matrix_a, _np.float64)
return _cast_to(matrix_a, _np.float64) if convert else _np.float64

# Raise an error - the dtype is invalid and cast is False
elif matrix_b is None:
Expand All @@ -809,7 +811,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
_is_valid_dtype(matrix_a, all_dtype=True) and
matrix_a.dtype == matrix_b.dtype
):
return matrix_a, matrix_b
return (matrix_a, matrix_b) if convert else matrix_a.dtype

# If neither matrix is complex and cast is True, convert to float64s
# and return them
Expand All @@ -818,7 +820,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
f"Recasting matrix data types {matrix_a.dtype} and "
f"{matrix_b.dtype} to np.float64"
)
return _cast_to(matrix_a, _np.float64), _cast_to(matrix_b, _np.float64)
return (_cast_to(matrix_a, _np.float64), _cast_to(matrix_b, _np.float64)) if convert else _np.float64

# If both matrices are complex and cast is True, convert to cdoubles
# and return them
Expand All @@ -827,7 +829,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
f"Recasting matrix data types {matrix_a.dtype} and "
f"{matrix_b.dtype} to _np.cdouble"
)
return _cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)
return (_cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)) if convert else _np.cdouble

# Cast reals and complex matrices together
elif (
Expand All @@ -838,7 +840,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
debug_print(
f"Recasting matrix data type {matrix_b.dtype} to {matrix_a.dtype}"
)
return matrix_a, _cast_to(matrix_b, matrix_a.dtype)
return (matrix_a, _cast_to(matrix_b, matrix_a.dtype)) if convert else matrix_a.dtype

elif (
cast and
Expand All @@ -848,14 +850,14 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
debug_print(
f"Recasting matrix data type {matrix_a.dtype} to {matrix_b.dtype}"
)
return _cast_to(matrix_a, matrix_b.dtype), matrix_b
return (_cast_to(matrix_a, matrix_b.dtype), matrix_b) if convert else matrix_b.dtype

elif cast and _n_complex == 1:
debug_print(
f"Recasting matrix data type {matrix_a.dtype} and {matrix_b.dtype}"
f" to np.cdouble"
)
return _cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)
return (_cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)) if convert else _np.cdouble

# If cast is False, can't cast anything together
elif not cast:
Expand Down Expand Up @@ -884,7 +886,7 @@ def _mkl_scalar(scalar, complex_type, double_precision):

def _out_matrix(shape, dtype, order="C", out_arr=None, out_t=False):
"""
Create an all-zero matrix or check to make sure that
Create an undefined matrix or check to make sure that
the provided output array matches
:param shape: Required output shape
Expand All @@ -905,7 +907,7 @@ def _out_matrix(shape, dtype, order="C", out_arr=None, out_t=False):

# If there's no output array allocate a new array and return it
if out_arr is None:
return _np.zeros(shape, dtype=dtype, order=order)
return _np.ndarray(shape, dtype=dtype, order=order)

# Check and make sure the order is correct
# Note 1d arrays have both flags set
Expand Down
17 changes: 9 additions & 8 deletions sparse_dot_mkl/_sparse_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _sparse_dense_matmul(

# Create a C struct if necessary to be passed
scalar = _mkl_scalar(scalar, cplx, dbl)
out_scalar = _mkl_scalar(out_scalar, cplx, dbl)
out_scalar = _mkl_scalar(0 if out is None else out_scalar, cplx, dbl)

ret_val = func(
11 if transpose else 10,
Expand Down Expand Up @@ -165,14 +165,15 @@ def _sparse_dot_dense(
debug_print(
"Skipping multiplication because A (dot) B must yield empty matrix"
)
final_dtype = (
np.float64
if matrix_a.dtype != matrix_b.dtype or matrix_a.dtype != np.float32
else np.float32
)
return _out_matrix(
(matrix_a.shape[0], matrix_b.shape[1]), final_dtype, out_arr=out
output_arr = _out_matrix(
(matrix_a.shape[0], matrix_b.shape[1]),
_type_check(matrix_a, matrix_b, cast=cast, convert=False), out_arr=out
)
if out is None or (out_scalar is not None and not out_scalar):
output_arr.fill(0)
elif out_scalar is not None:
output_arr *= out_scalar
return output_arr

matrix_a, matrix_b = _type_check(matrix_a, matrix_b, cast=cast)

Expand Down
9 changes: 6 additions & 3 deletions sparse_dot_mkl/_sparse_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,19 @@ def _sparse_dot_sparse(
# Check for edge condition inputs which result in empty outputs
if _empty_output_check(matrix_a, matrix_b):

final_dtype = _type_check(matrix_a, matrix_b, cast=cast, convert=False)
if dense:
return _out_matrix(
output_arr = _out_matrix(
(matrix_a.shape[0], matrix_b.shape[1]),
matrix_a.dtype,
final_dtype,
out_arr=out
)
output_arr.fill(0)
return output_arr
else:
return default_output(
(matrix_a.shape[0], matrix_b.shape[1]),
dtype=matrix_a.dtype
dtype=final_dtype
)

# Check dtypes
Expand Down
4 changes: 2 additions & 2 deletions sparse_dot_mkl/_sparse_sypr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def _sypr_sparse_A_dense_B(
matrix_b,
layout_b,
ld_b,
float(out_scalar) if a_scalar is not None else 1.0,
float(out_scalar) if out_scalar is not None else 1.0,
float(a_scalar) if a_scalar is not None else 1.0,
0.0 if out is None else (float(out_scalar) if out_scalar is not None else 1.0),
output_arr,
output_layout,
output_ld,
Expand Down
22 changes: 13 additions & 9 deletions sparse_dot_mkl/_sparse_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@ def _sparse_dense_vector_mult(
output_shape = matrix_a.shape[1] if transpose else matrix_a.shape[0]
output_shape = (output_shape,) if vector_b.ndim == 1 else (output_shape, 1)

if _empty_output_check(matrix_a, vector_b):
final_dtype = (
np.float64
if matrix_a.dtype != vector_b.dtype or matrix_a.dtype != np.float32
else np.float32
)
return _out_matrix(output_shape, final_dtype, out_arr=out)

mkl_a, dbl, cplx = _create_mkl_sparse(matrix_a)
vector_b = vector_b.ravel()

Expand All @@ -75,7 +67,7 @@ def _sparse_dense_vector_mult(

# Create a C struct if necessary to be passed
scalar = _mkl_scalar(scalar, cplx, dbl)
out_scalar = _mkl_scalar(out_scalar, cplx, dbl)
out_scalar = _mkl_scalar(0 if out is None else out_scalar, cplx, dbl)

output_arr = _out_matrix(
output_shape,
Expand Down Expand Up @@ -135,6 +127,18 @@ def _sparse_dot_vector(
"""

_sanity_check(mv_a, mv_b, allow_vector=True)

if _empty_output_check(mv_a, mv_b):
output_arr = _out_matrix(
(mv_a.shape[0],) if mv_b.ndim == 1 else (mv_a.shape[0], 1),
_type_check(mv_a, mv_b, cast=cast, convert=False), out_arr=out
)
if out is None or (out_scalar is not None and not out_scalar):
output_arr.fill(0)
elif out_scalar is not None:
output_arr *= out_scalar
return output_arr

mv_a, mv_b = _type_check(mv_a, mv_b, cast=cast)

if (
Expand Down

0 comments on commit 59bcb1a

Please sign in to comment.