From 59bcb1a079ec47fffa302ca5883caa37fa0165fa Mon Sep 17 00:00:00 2001 From: Benjamin Desef Date: Wed, 6 Nov 2024 12:57:00 +0100 Subject: [PATCH] Fix #28 --- sparse_dot_mkl/_dense_dense.py | 11 +++++++--- sparse_dot_mkl/_gram_matrix.py | 8 ++++++-- sparse_dot_mkl/_mkl_interface/_common.py | 26 +++++++++++++----------- sparse_dot_mkl/_sparse_dense.py | 17 ++++++++-------- sparse_dot_mkl/_sparse_sparse.py | 9 +++++--- sparse_dot_mkl/_sparse_sypr.py | 4 ++-- sparse_dot_mkl/_sparse_vector.py | 22 ++++++++++++-------- 7 files changed, 58 insertions(+), 39 deletions(-) diff --git a/sparse_dot_mkl/_dense_dense.py b/sparse_dot_mkl/_dense_dense.py index 7730f3e..55f9354 100644 --- a/sparse_dot_mkl/_dense_dense.py +++ b/sparse_dot_mkl/_dense_dense.py @@ -48,7 +48,7 @@ def _dense_matmul(matrix_a, matrix_b, scalar=1., out=None, out_scalar=None): # The complex versions of these functions take void pointers instead of passed structs # So create a C struct if necessary to be passed by reference scalar = _mkl_scalar(scalar, complex_type, double_precision) - out_scalar = _mkl_scalar(out_scalar, complex_type, double_precision) + out_scalar = _mkl_scalar(0 if out is None else out_scalar, complex_type, double_precision) func(layout_a, 111, @@ -75,8 +75,13 @@ def _dense_dot_dense(matrix_a, matrix_b, cast=False, scalar=1., out=None, out_sc # Check for edge condition inputs which result in empty outputs if _empty_output_check(matrix_a, matrix_b): debug_print("Skipping multiplication because A (dot) B must yield an empty matrix") - final_dtype = np.float64 if matrix_a.dtype != matrix_b.dtype or matrix_a.dtype != np.float32 else np.float32 - return _out_matrix((matrix_a.shape[0], matrix_b.shape[1]), final_dtype, out_arr=out) + output_arr = _out_matrix((matrix_a.shape[0], matrix_b.shape[1]), + _type_check(matrix_a, matrix_b, cast=cast, convert=False), out_arr=out) + if out is None or (out_scalar is not None and not out_scalar): + output_arr.fill(0) + elif out_scalar is not None: + output_arr *= out_scalar + return output_arr matrix_a, matrix_b = _type_check(matrix_a, matrix_b, cast=cast) diff --git a/sparse_dot_mkl/_gram_matrix.py b/sparse_dot_mkl/_gram_matrix.py index 79e556a..cbf6813 100644 --- a/sparse_dot_mkl/_gram_matrix.py +++ b/sparse_dot_mkl/_gram_matrix.py @@ -141,10 +141,14 @@ def _gram_matrix_sparse_to_dense( if _empty_output_check(matrix_a, matrix_a): _destroy_mkl_handle(sp_ref_a) + if out is None or (out_scalar is not None and not out_scalar): + output_arr.fill(0) + elif out_scalar is not None: + output_arr *= out_scalar return output_arr scalar = _mkl_scalar(scalar, complex_type, double_prec) - out_scalar = _mkl_scalar(out_scalar, complex_type, double_prec) + out_scalar = _mkl_scalar(0 if out is None else out_scalar, complex_type, double_prec) ret_val = func( _mkl_sp_transpose_ops[(not aat, complex_type)], @@ -230,7 +234,7 @@ def _gram_matrix_dense_to_dense( # passed structs, so create a C struct if necessary to be passed by # reference scalar = _mkl_scalar(scalar, complex_type, double_precision) - out_scalar = _mkl_scalar(out_scalar, complex_type, double_precision) + out_scalar = _mkl_scalar(0 if out is None else out_scalar, complex_type, double_precision) func( layout_a, diff --git a/sparse_dot_mkl/_mkl_interface/_common.py b/sparse_dot_mkl/_mkl_interface/_common.py index e9d5893..74e4854 100644 --- a/sparse_dot_mkl/_mkl_interface/_common.py +++ b/sparse_dot_mkl/_mkl_interface/_common.py @@ -770,12 +770,14 @@ def _is_valid_dtype(matrix, complex_dtype=False, all_dtype=False): return matrix.dtype in NUMPY_FLOAT_DTYPES -def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True): +def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True, convert=True): """ Make sure that both matrices are single precision floats or both are double precision floats. If not, convert to double precision floats if cast is True, or raise an error if cast is False + If convert is set to False, the resulting data type is returned without + any conversion happening. """ _n_complex = _np.iscomplexobj(matrix_a) + _np.iscomplexobj(matrix_b) @@ -785,17 +787,17 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True): # If there's no matrix B and matrix A is valid dtype, return it if matrix_b is None and _is_valid_dtype(matrix_a, all_dtype=True): - return matrix_a + return matrix_a if convert else matrix_a.dtype # If matrix A is complex but not csingle or cdouble, and cast is True, # convert it to a cdouble elif matrix_b is None and cast and _n_complex == 1: - return _cast_to(matrix_a, _np.cdouble) + return _cast_to(matrix_a, _np.cdouble) if convert else _np.cdouble # If matrix A is real but not float32 or float64, and cast is True, # convert it to a float64 elif matrix_b is None and cast: - return _cast_to(matrix_a, _np.float64) + return _cast_to(matrix_a, _np.float64) if convert else _np.float64 # Raise an error - the dtype is invalid and cast is False elif matrix_b is None: @@ -809,7 +811,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True): _is_valid_dtype(matrix_a, all_dtype=True) and matrix_a.dtype == matrix_b.dtype ): - return matrix_a, matrix_b + return (matrix_a, matrix_b) if convert else matrix_a.dtype # If neither matrix is complex and cast is True, convert to float64s # and return them @@ -818,7 +820,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True): f"Recasting matrix data types {matrix_a.dtype} and " f"{matrix_b.dtype} to np.float64" ) - return _cast_to(matrix_a, _np.float64), _cast_to(matrix_b, _np.float64) + return (_cast_to(matrix_a, _np.float64), _cast_to(matrix_b, _np.float64)) if convert else _np.float64 # If both matrices are complex and cast is True, convert to cdoubles # and return them @@ -827,7 +829,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True): f"Recasting matrix data types {matrix_a.dtype} and " f"{matrix_b.dtype} to _np.cdouble" ) - return _cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble) + return (_cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)) if convert else _np.cdouble # Cast reals and complex matrices together elif ( @@ -838,7 +840,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True): debug_print( f"Recasting matrix data type {matrix_b.dtype} to {matrix_a.dtype}" ) - return matrix_a, _cast_to(matrix_b, matrix_a.dtype) + return (matrix_a, _cast_to(matrix_b, matrix_a.dtype)) if convert else matrix_a.dtype elif ( cast and @@ -848,14 +850,14 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True): debug_print( f"Recasting matrix data type {matrix_a.dtype} to {matrix_b.dtype}" ) - return _cast_to(matrix_a, matrix_b.dtype), matrix_b + return (_cast_to(matrix_a, matrix_b.dtype), matrix_b) if convert else matrix_b.dtype elif cast and _n_complex == 1: debug_print( f"Recasting matrix data type {matrix_a.dtype} and {matrix_b.dtype}" f" to np.cdouble" ) - return _cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble) + return (_cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)) if convert else _np.cdouble # If cast is False, can't cast anything together elif not cast: @@ -884,7 +886,7 @@ def _mkl_scalar(scalar, complex_type, double_precision): def _out_matrix(shape, dtype, order="C", out_arr=None, out_t=False): """ - Create an all-zero matrix or check to make sure that + Create an undefined matrix or check to make sure that the provided output array matches :param shape: Required output shape @@ -905,7 +907,7 @@ def _out_matrix(shape, dtype, order="C", out_arr=None, out_t=False): # If there's no output array allocate a new array and return it if out_arr is None: - return _np.zeros(shape, dtype=dtype, order=order) + return _np.ndarray(shape, dtype=dtype, order=order) # Check and make sure the order is correct # Note 1d arrays have both flags set diff --git a/sparse_dot_mkl/_sparse_dense.py b/sparse_dot_mkl/_sparse_dense.py index c4200cf..958feca 100644 --- a/sparse_dot_mkl/_sparse_dense.py +++ b/sparse_dot_mkl/_sparse_dense.py @@ -106,7 +106,7 @@ def _sparse_dense_matmul( # Create a C struct if necessary to be passed scalar = _mkl_scalar(scalar, cplx, dbl) - out_scalar = _mkl_scalar(out_scalar, cplx, dbl) + out_scalar = _mkl_scalar(0 if out is None else out_scalar, cplx, dbl) ret_val = func( 11 if transpose else 10, @@ -165,14 +165,15 @@ def _sparse_dot_dense( debug_print( "Skipping multiplication because A (dot) B must yield empty matrix" ) - final_dtype = ( - np.float64 - if matrix_a.dtype != matrix_b.dtype or matrix_a.dtype != np.float32 - else np.float32 - ) - return _out_matrix( - (matrix_a.shape[0], matrix_b.shape[1]), final_dtype, out_arr=out + output_arr = _out_matrix( + (matrix_a.shape[0], matrix_b.shape[1]), + _type_check(matrix_a, matrix_b, cast=cast, convert=False), out_arr=out ) + if out is None or (out_scalar is not None and not out_scalar): + output_arr.fill(0) + elif out_scalar is not None: + output_arr *= out_scalar + return output_arr matrix_a, matrix_b = _type_check(matrix_a, matrix_b, cast=cast) diff --git a/sparse_dot_mkl/_sparse_sparse.py b/sparse_dot_mkl/_sparse_sparse.py index 5fd70ab..c7523b4 100644 --- a/sparse_dot_mkl/_sparse_sparse.py +++ b/sparse_dot_mkl/_sparse_sparse.py @@ -169,16 +169,19 @@ def _sparse_dot_sparse( # Check for edge condition inputs which result in empty outputs if _empty_output_check(matrix_a, matrix_b): + final_dtype = _type_check(matrix_a, matrix_b, cast=cast, convert=False) if dense: - return _out_matrix( + output_arr = _out_matrix( (matrix_a.shape[0], matrix_b.shape[1]), - matrix_a.dtype, + final_dtype, out_arr=out ) + output_arr.fill(0) + return output_arr else: return default_output( (matrix_a.shape[0], matrix_b.shape[1]), - dtype=matrix_a.dtype + dtype=final_dtype ) # Check dtypes diff --git a/sparse_dot_mkl/_sparse_sypr.py b/sparse_dot_mkl/_sparse_sypr.py index 054bdff..71109a8 100644 --- a/sparse_dot_mkl/_sparse_sypr.py +++ b/sparse_dot_mkl/_sparse_sypr.py @@ -69,8 +69,8 @@ def _sypr_sparse_A_dense_B( matrix_b, layout_b, ld_b, - float(out_scalar) if a_scalar is not None else 1.0, - float(out_scalar) if out_scalar is not None else 1.0, + float(a_scalar) if a_scalar is not None else 1.0, + 0.0 if out is None else (float(out_scalar) if out_scalar is not None else 1.0), output_arr, output_layout, output_ld, diff --git a/sparse_dot_mkl/_sparse_vector.py b/sparse_dot_mkl/_sparse_vector.py index f28202d..ed1e653 100644 --- a/sparse_dot_mkl/_sparse_vector.py +++ b/sparse_dot_mkl/_sparse_vector.py @@ -56,14 +56,6 @@ def _sparse_dense_vector_mult( output_shape = matrix_a.shape[1] if transpose else matrix_a.shape[0] output_shape = (output_shape,) if vector_b.ndim == 1 else (output_shape, 1) - if _empty_output_check(matrix_a, vector_b): - final_dtype = ( - np.float64 - if matrix_a.dtype != vector_b.dtype or matrix_a.dtype != np.float32 - else np.float32 - ) - return _out_matrix(output_shape, final_dtype, out_arr=out) - mkl_a, dbl, cplx = _create_mkl_sparse(matrix_a) vector_b = vector_b.ravel() @@ -75,7 +67,7 @@ def _sparse_dense_vector_mult( # Create a C struct if necessary to be passed scalar = _mkl_scalar(scalar, cplx, dbl) - out_scalar = _mkl_scalar(out_scalar, cplx, dbl) + out_scalar = _mkl_scalar(0 if out is None else out_scalar, cplx, dbl) output_arr = _out_matrix( output_shape, @@ -135,6 +127,18 @@ def _sparse_dot_vector( """ _sanity_check(mv_a, mv_b, allow_vector=True) + + if _empty_output_check(mv_a, mv_b): + output_arr = _out_matrix( + (mv_a.shape[0],) if mv_b.ndim == 1 else (mv_a.shape[0], 1), + _type_check(mv_a, mv_b, cast=cast, convert=False), out_arr=out + ) + if out is None or (out_scalar is not None and not out_scalar): + output_arr.fill(0) + elif out_scalar is not None: + output_arr *= out_scalar + return output_arr + mv_a, mv_b = _type_check(mv_a, mv_b, cast=cast) if (