diff --git a/CMakeLists.txt b/CMakeLists.txt
index 76f5aedc2..d681945cb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -62,6 +62,7 @@ option(ENABLE_PORTFFT_BACKEND "Enable the portFFT DFT backend for the DFT interf
# sparse
option(ENABLE_CUSPARSE_BACKEND "Enable the cuSPARSE backend for the SPARSE_BLAS interface" OFF)
+option(ENABLE_ROCSPARSE_BACKEND "Enable the rocSPARSE backend for the SPARSE_BLAS interface" OFF)
set(ONEMKL_SYCL_IMPLEMENTATION "dpc++" CACHE STRING "Name of the SYCL compiler")
set(HIP_TARGETS "" CACHE STRING "Target HIP architectures")
@@ -106,7 +107,8 @@ if(ENABLE_MKLGPU_BACKEND
endif()
if(ENABLE_MKLCPU_BACKEND
OR ENABLE_MKLGPU_BACKEND
- OR ENABLE_CUSPARSE_BACKEND)
+ OR ENABLE_CUSPARSE_BACKEND
+ OR ENABLE_ROCSPARSE_BACKEND)
list(APPEND DOMAINS_LIST "sparse_blas")
endif()
@@ -134,7 +136,7 @@ if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++")
endif()
else()
if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND OR ENABLE_CUSPARSE_BACKEND
- OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND)
+ OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND OR ENABLE_ROCSPARSE_BACKEND)
set(CMAKE_CXX_COMPILER "clang++")
elseif(ENABLE_MKLGPU_BACKEND)
if(UNIX)
diff --git a/README.md b/README.md
index dc023c67c..c1a2c5207 100644
--- a/README.md
+++ b/README.md
@@ -18,8 +18,8 @@ oneMKL is part of the [UXL Foundation](http://www.uxlfoundation.org).
-      <td rowspan=13>oneMKL interface</td>
-      <td rowspan=13>oneMKL selector</td>
+      <td rowspan=14>oneMKL interface</td>
+      <td rowspan=14>oneMKL selector</td>
       <td>Intel(R) oneAPI Math Kernel Library (oneMKL)</td>
       <td>x86 CPU, Intel GPU</td>
@@ -61,7 +61,11 @@ oneMKL is part of the [UXL Foundation](http://www.uxlfoundation.org).
       <td>AMD GPU</td>
-      <td>AMD rocFFT</td>
+      <td>AMD rocFFT</td>
+      <td>AMD GPU</td>
+    </tr>
+    <tr>
+      <td>AMD rocSPARSE</td>
       <td>AMD GPU</td>
@@ -333,7 +337,7 @@ Supported compilers include:
       <td>Dynamic, Static</td>
-      <td rowspan=3>SPARSE_BLAS</td>
+      <td rowspan=4>SPARSE_BLAS</td>
       <td>x86 CPU</td>
       <td>Intel(R) oneMKL</td>
       <td>Intel DPC++</td>
@@ -351,6 +355,12 @@ Supported compilers include:
       <td>Open DPC++</td>
       <td>Dynamic, Static</td>
+    <tr>
+      <td>AMD GPU</td>
+      <td>AMD rocSPARSE</td>
+      <td>Open DPC++</td>
+      <td>Dynamic, Static</td>
+    </tr>
@@ -537,6 +547,7 @@ Product | Supported Version | License
[AMD rocRAND](https://github.com/ROCm/rocRAND) | 5.1.0 | [AMD License](https://github.com/ROCm/rocRAND/blob/develop/LICENSE.txt)
[AMD rocSOLVER](https://github.com/ROCm/rocSOLVER) | 5.0.0 | [AMD License](https://github.com/ROCm/rocSOLVER/blob/develop/LICENSE.md)
[AMD rocFFT](https://github.com/ROCm/rocFFT) | rocm-5.4.3 | [AMD License](https://github.com/ROCm/rocFFT/blob/rocm-5.4.3/LICENSE.md)
+[AMD rocSPARSE](https://github.com/ROCm/rocSPARSE) | 3.1.2 | [AMD License](https://github.com/ROCm/rocSPARSE/blob/develop/LICENSE.md)
[NETLIB LAPACK](https://www.netlib.org/) | [5d4180c](https://github.com/Reference-LAPACK/lapack/commit/5d4180cf8288ae6ad9a771d18793d15bd0c5643c) | [BSD like license](http://www.netlib.org/lapack/LICENSE.txt)
[portBLAS](https://github.com/codeplaysoftware/portBLAS) | 0.1 | [Apache License v2.0](https://github.com/codeplaysoftware/portBLAS/blob/main/LICENSE)
[portFFT](https://github.com/codeplaysoftware/portFFT) | 0.1 | [Apache License v2.0](https://github.com/codeplaysoftware/portFFT/blob/main/LICENSE)
diff --git a/cmake/FindCompiler.cmake b/cmake/FindCompiler.cmake
index 8aefc2623..f8c35ee5c 100644
--- a/cmake/FindCompiler.cmake
+++ b/cmake/FindCompiler.cmake
@@ -43,7 +43,7 @@ if(is_dpcpp)
list(APPEND UNIX_INTERFACE_LINK_OPTIONS
-fsycl-targets=nvptx64-nvidia-cuda)
elseif(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND
- OR ENABLE_ROCSOLVER_BACKEND)
+ OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND)
list(APPEND UNIX_INTERFACE_COMPILE_OPTIONS
-fsycl-targets=amdgcn-amd-amdhsa -fsycl-unnamed-lambda
-Xsycl-target-backend --offload-arch=${HIP_TARGETS})
@@ -52,7 +52,7 @@ if(is_dpcpp)
--offload-arch=${HIP_TARGETS})
endif()
if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUSPARSE_BACKEND OR ENABLE_ROCBLAS_BACKEND
- OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND)
+ OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND)
set_target_properties(ONEMKL::SYCL::SYCL PROPERTIES
INTERFACE_COMPILE_OPTIONS "${UNIX_INTERFACE_COMPILE_OPTIONS}"
INTERFACE_LINK_OPTIONS "${UNIX_INTERFACE_LINK_OPTIONS}"
@@ -69,7 +69,7 @@ if(is_dpcpp)
INTERFACE_LINK_LIBRARIES ${SYCL_LIBRARY})
endif()
- if(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND)
+ if(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND)
# Allow find_package(HIP) to find the correct path to libclang_rt.builtins.a
# HIP's CMake uses the command `${HIP_CXX_COMPILER} -print-libgcc-file-name --rtlib=compiler-rt` to find this path.
# This can print a non-existing file if the compiler used is icpx.
diff --git a/docs/building_the_project_with_dpcpp.rst b/docs/building_the_project_with_dpcpp.rst
index efe92f285..2ad6d6504 100644
--- a/docs/building_the_project_with_dpcpp.rst
+++ b/docs/building_the_project_with_dpcpp.rst
@@ -122,6 +122,9 @@ The most important supported build options are:
* - ENABLE_ROCRAND_BACKEND
- True, False
- False
+ * - ENABLE_ROCSPARSE_BACKEND
+ - True, False
+ - False
* - ENABLE_MKLCPU_THREAD_TBB
- True, False
- True
@@ -198,14 +201,14 @@ Building for ROCm
^^^^^^^^^^^^^^^^^
The ROCm backends can be enabled with ``ENABLE_ROCBLAS_BACKEND``,
-``ENABLE_ROCFFT_BACKEND``, ``ENABLE_ROCSOLVER_BACKEND`` and
-``ENABLE_ROCRAND_BACKEND``.
+``ENABLE_ROCFFT_BACKEND``, ``ENABLE_ROCSOLVER_BACKEND``,
+``ENABLE_ROCRAND_BACKEND``, and ``ENABLE_ROCSPARSE_BACKEND``.
-For *RocBLAS*, *RocSOLVER* and *RocRAND*, the target device architecture must be
-set. This can be set with using the ``HIP_TARGETS`` parameter. For example, to
-enable a build for MI200 series GPUs, ``-DHIP_TARGETS=gfx90a`` should be set.
-Currently, DPC++ can only build for a single HIP target at a time. This may
-change in future versions.
+For *rocBLAS*, *rocSOLVER*, *rocRAND*, and *rocSPARSE*, the target device
+architecture must be set using the ``HIP_TARGETS`` parameter. For example, to
+enable a build for MI200 series GPUs, ``-DHIP_TARGETS=gfx90a`` should be set.
+Currently, DPC++ can only build for a single HIP target at a time. This may
+change in future versions.
A few often-used architectures are listed below:
@@ -394,7 +397,8 @@ disabled:
-DENABLE_MKLGPU_BACKEND=False \
-DENABLE_ROCFFT_BACKEND=True \
-DENABLE_ROCBLAS_BACKEND=True \
- -DENABLE_ROCSOLVER_BACKEND=True \
+ -DENABLE_ROCSOLVER_BACKEND=True \
+ -DENABLE_ROCSPARSE_BACKEND=True \
-DHIP_TARGETS=gfx90a \
-DBUILD_FUNCTIONAL_TESTS=False
diff --git a/docs/domains/sparse_linear_algebra.rst b/docs/domains/sparse_linear_algebra.rst
index 07d90359a..adfd0ca98 100644
--- a/docs/domains/sparse_linear_algebra.rst
+++ b/docs/domains/sparse_linear_algebra.rst
@@ -68,6 +68,31 @@ Currently known limitations:
``cusparseSpMV_preprocess``. Feel free to create an issue if this is needed.
+rocSPARSE backend
+-----------------
+
+Currently known limitations:
+
+- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will
+ throw a ``oneapi::mkl::unimplemented`` exception.
+- The COO format requires the indices to be sorted by row then by column. See
+ the `rocSPARSE COO documentation
+ `_.
+ Sparse operations using matrices with the COO format without the property
+ ``matrix_property::sorted`` will throw a ``oneapi::mkl::unimplemented``
+ exception.
+- The CSR format requires the column indices to be sorted within each row. See
+ the `rocSPARSE CSR documentation
+ `_.
+ Sparse operations using matrices with the CSR format without the property
+ ``matrix_property::sorted`` will throw a ``oneapi::mkl::unimplemented``
+ exception.
+- The same sparse matrix handle cannot be reused across multiple operations
+  (``spmm``, ``spmv``, or ``spsv``). Doing so will throw a
+ ``oneapi::mkl::unimplemented`` exception. See `#332
+ `_.
+
+
Operation algorithms mapping
----------------------------
@@ -89,33 +114,43 @@ spmm
* - ``spmm_alg`` value
- MKLCPU/MKLGPU
- cuSPARSE
+ - rocSPARSE
* - ``default_alg``
- none
- ``CUSPARSE_SPMM_ALG_DEFAULT``
+ - ``rocsparse_spmm_alg_default``
* - ``no_optimize_alg``
- none
- ``CUSPARSE_SPMM_ALG_DEFAULT``
+ - ``rocsparse_spmm_alg_default``
* - ``coo_alg1``
- none
- ``CUSPARSE_SPMM_COO_ALG1``
+ - ``rocsparse_spmm_alg_coo_segmented``
* - ``coo_alg2``
- none
- ``CUSPARSE_SPMM_COO_ALG2``
+ - ``rocsparse_spmm_alg_coo_atomic``
* - ``coo_alg3``
- none
- ``CUSPARSE_SPMM_COO_ALG3``
+ - ``rocsparse_spmm_alg_coo_segmented_atomic``
* - ``coo_alg4``
- none
- ``CUSPARSE_SPMM_COO_ALG4``
+ - ``rocsparse_spmm_alg_default``
* - ``csr_alg1``
- none
- ``CUSPARSE_SPMM_CSR_ALG1``
+ - ``rocsparse_spmm_alg_csr``
* - ``csr_alg2``
- none
- ``CUSPARSE_SPMM_CSR_ALG2``
+ - ``rocsparse_spmm_alg_csr_row_split``
* - ``csr_alg3``
- none
- ``CUSPARSE_SPMM_CSR_ALG3``
+ - ``rocsparse_spmm_alg_csr_merge``
spmv
@@ -128,27 +163,35 @@ spmv
* - ``spmv_alg`` value
- MKLCPU/MKLGPU
- cuSPARSE
+ - rocSPARSE
* - ``default_alg``
- none
- ``CUSPARSE_SPMV_ALG_DEFAULT``
+ - ``rocsparse_spmv_alg_default``
* - ``no_optimize_alg``
- none
- ``CUSPARSE_SPMV_ALG_DEFAULT``
+ - ``rocsparse_spmv_alg_default``
* - ``coo_alg1``
- none
- ``CUSPARSE_SPMV_COO_ALG1``
+ - ``rocsparse_spmv_alg_coo``
* - ``coo_alg2``
- none
- ``CUSPARSE_SPMV_COO_ALG2``
+ - ``rocsparse_spmv_alg_coo_atomic``
* - ``csr_alg1``
- none
- ``CUSPARSE_SPMV_CSR_ALG1``
+ - ``rocsparse_spmv_alg_csr_adaptive``
* - ``csr_alg2``
- none
- ``CUSPARSE_SPMV_CSR_ALG2``
+ - ``rocsparse_spmv_alg_csr_stream``
* - ``csr_alg3``
- none
- ``CUSPARSE_SPMV_ALG_DEFAULT``
+ - ``rocsparse_spmv_alg_csr_lrb``
spsv
@@ -161,9 +204,12 @@ spsv
* - ``spsv_alg`` value
- MKLCPU/MKLGPU
- cuSPARSE
+ - rocSPARSE
* - ``default_alg``
- none
- ``CUSPARSE_SPSV_ALG_DEFAULT``
+ - ``rocsparse_spsv_alg_default``
* - ``no_optimize_alg``
- none
- ``CUSPARSE_SPSV_ALG_DEFAULT``
+ - ``rocsparse_spsv_alg_default``
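To see how the limitations and the algorithm mapping above combine in practice, the following is a minimal usage sketch, not taken from this patch: it assumes the sparse_blas run-time entry points (``init_csr_matrix``, ``set_matrix_property``, the ``spmv_*`` descriptor functions) and USM data, so the exact signatures should be checked against ``oneapi/mkl/sparse_blas.hpp``.

```cpp
#include <cstdint>
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

namespace sparse = oneapi::mkl::sparse;

// Hedged sketch: declare matrix_property::sorted (required by the rocSPARSE
// backend, see the limitations above) and pick spmv_alg::csr_alg1, which the
// table maps to rocsparse_spmv_alg_csr_adaptive.
sycl::event sorted_csr_spmv(sycl::queue& q, std::int64_t m, std::int64_t n, std::int64_t nnz,
                            std::int32_t* row_ptr, std::int32_t* col_ind, float* val,
                            float* x, float* y) {
    float alpha = 1.0f, beta = 0.0f;
    sparse::matrix_view A_view; // type_view defaults to matrix_descr::general

    sparse::matrix_handle_t A = nullptr;
    sparse::init_csr_matrix(q, &A, m, n, nnz, oneapi::mkl::index_base::zero, row_ptr,
                            col_ind, val);
    // Without this property the rocSPARSE backend throws oneapi::mkl::unimplemented.
    sparse::set_matrix_property(q, A, sparse::matrix_property::sorted);

    sparse::dense_vector_handle_t x_h = nullptr, y_h = nullptr;
    sparse::init_dense_vector(q, &x_h, n, x);
    sparse::init_dense_vector(q, &y_h, m, y);

    sparse::spmv_descr_t descr = nullptr;
    sparse::init_spmv_descr(q, &descr);
    auto alg = sparse::spmv_alg::csr_alg1;

    std::size_t tmp_size = 0;
    sparse::spmv_buffer_size(q, oneapi::mkl::transpose::nontrans, &alpha, A_view, A, x_h,
                             &beta, y_h, alg, descr, tmp_size);
    void* tmp = sycl::malloc_device(tmp_size, q);
    auto opt_ev = sparse::spmv_optimize(q, oneapi::mkl::transpose::nontrans, &alpha, A_view,
                                        A, x_h, &beta, y_h, alg, descr, tmp, {});
    // The matching release_* calls and sycl::free are omitted for brevity.
    return sparse::spmv(q, oneapi::mkl::transpose::nontrans, &alpha, A_view, A, x_h, &beta,
                        y_h, alg, descr, { opt_ev });
}
```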
diff --git a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt
index f09daf819..fb425ef16 100644
--- a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt
+++ b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt
@@ -36,6 +36,9 @@ endif()
if(ENABLE_CUSPARSE_BACKEND)
list(APPEND DEVICE_FILTERS "cuda:gpu")
endif()
+if(ENABLE_ROCSPARSE_BACKEND)
+ list(APPEND DEVICE_FILTERS "hip:gpu")
+endif()
message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples")
diff --git a/include/oneapi/mkl/detail/backends.hpp b/include/oneapi/mkl/detail/backends.hpp
index 216a6feba..3b30ff417 100644
--- a/include/oneapi/mkl/detail/backends.hpp
+++ b/include/oneapi/mkl/detail/backends.hpp
@@ -41,6 +41,7 @@ enum class backend {
rocfft,
portfft,
cusparse,
+ rocsparse,
unsupported
};
@@ -63,6 +64,7 @@ static backendmap backend_map = { { backend::mklcpu, "mklcpu" },
{ backend::rocfft, "rocfft" },
{ backend::portfft, "portfft" },
{ backend::cusparse, "cusparse" },
+ { backend::rocsparse, "rocsparse" },
{ backend::unsupported, "unsupported" } };
// clang-format on
diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp
index 9b7c921d6..81eb6d5a0 100644
--- a/include/oneapi/mkl/detail/backends_table.hpp
+++ b/include/oneapi/mkl/detail/backends_table.hpp
@@ -204,6 +204,12 @@ static std::map<domain, std::map<device, std::vector<const char*>>> libraries =
{
#ifdef ONEMKL_ENABLE_CUSPARSE_BACKEND
LIB_NAME("sparse_blas_cusparse")
+#endif
+ } },
+ { device::amdgpu,
+ {
+#ifdef ONEMKL_ENABLE_ROCSPARSE_BACKEND
+ LIB_NAME("sparse_blas_rocsparse")
#endif
} } } },
};
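With this table entry in place, a run-time-dispatch call on a queue targeting an AMD GPU resolves to the ``sparse_blas_rocsparse`` library. A hedged sketch of what that looks like from user code (the descriptor functions are assumed from the run-time API):

```cpp
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

int main() {
    // Under a HIP plugin this selects an AMD GPU (device::amdgpu above), so
    // sparse calls are routed to LIB_NAME("sparse_blas_rocsparse") at run time.
    sycl::queue q{ sycl::gpu_selector_v };
    oneapi::mkl::sparse::spmv_descr_t descr = nullptr;
    oneapi::mkl::sparse::init_spmv_descr(q, &descr);
    oneapi::mkl::sparse::release_spmv_descr(q, descr).wait();
    return 0;
}
```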
diff --git a/include/oneapi/mkl/sparse_blas.hpp b/include/oneapi/mkl/sparse_blas.hpp
index 8fb86f244..7562bb7a3 100644
--- a/include/oneapi/mkl/sparse_blas.hpp
+++ b/include/oneapi/mkl/sparse_blas.hpp
@@ -37,6 +37,9 @@
#ifdef ONEMKL_ENABLE_CUSPARSE_BACKEND
#include "sparse_blas/detail/cusparse/sparse_blas_ct.hpp"
#endif
+#ifdef ONEMKL_ENABLE_ROCSPARSE_BACKEND
+#include "sparse_blas/detail/rocsparse/sparse_blas_ct.hpp"
+#endif
#include "sparse_blas/detail/sparse_blas_rt.hpp"
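The header included above provides the compile-time entry points. A hedged sketch of selecting the rocsparse backend explicitly through ``backend_selector``, mirroring how the other domains' ``sparse_blas_ct.hpp`` headers are used (assumed usage, not part of this patch):

```cpp
#include <cstdint>
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

// Bind the sparse call to the rocSPARSE backend at compile time instead of
// relying on run-time dispatch.
void init_matrix_on_rocsparse(sycl::queue& q, oneapi::mkl::sparse::matrix_handle_t* p_A,
                              std::int64_t m, std::int64_t n, std::int64_t nnz,
                              std::int32_t* row_ptr, std::int32_t* col_ind, float* val) {
    oneapi::mkl::backend_selector<oneapi::mkl::backend::rocsparse> selector{ q };
    oneapi::mkl::sparse::init_csr_matrix(selector, p_A, m, n, nnz,
                                         oneapi::mkl::index_base::zero, row_ptr, col_ind,
                                         val);
}
```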
diff --git a/include/oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp b/include/oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp
new file mode 100644
index 000000000..951978921
--- /dev/null
+++ b/include/oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp
@@ -0,0 +1,33 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMKL_SPARSE_BLAS_ROCSPARSE_HPP_
+#define _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMKL_SPARSE_BLAS_ROCSPARSE_HPP_
+
+#include "oneapi/mkl/detail/export.hpp"
+#include "oneapi/mkl/sparse_blas/detail/helper_types.hpp"
+#include "oneapi/mkl/sparse_blas/types.hpp"
+
+namespace oneapi::mkl::sparse::rocsparse {
+
+#include "oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx"
+
+} // namespace oneapi::mkl::sparse::rocsparse
+
+#endif // _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMKL_SPARSE_BLAS_ROCSPARSE_HPP_
diff --git a/include/oneapi/mkl/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp b/include/oneapi/mkl/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp
new file mode 100644
index 000000000..645230fa6
--- /dev/null
+++ b/include/oneapi/mkl/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp
@@ -0,0 +1,40 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_
+#define _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_
+
+#include "oneapi/mkl/detail/backends.hpp"
+#include "oneapi/mkl/detail/backend_selector.hpp"
+
+#include "onemkl_sparse_blas_rocsparse.hpp"
+
+namespace oneapi {
+namespace mkl {
+namespace sparse {
+
+#define BACKEND rocsparse
+#include "oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx"
+#undef BACKEND
+
+} //namespace sparse
+} //namespace mkl
+} //namespace oneapi
+
+#endif // _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c363d8a8d..01365fd14 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -60,6 +60,7 @@ function(generate_header_file)
set(ONEMKL_ENABLE_ROCFFT_BACKEND ${ENABLE_ROCFFT_BACKEND})
set(ONEMKL_ENABLE_PORTFFT_BACKEND ${ENABLE_PORTFFT_BACKEND})
set(ONEMKL_ENABLE_CUSPARSE_BACKEND ${ENABLE_CUSPARSE_BACKEND})
+ set(ONEMKL_ENABLE_ROCSPARSE_BACKEND ${ENABLE_ROCSPARSE_BACKEND})
configure_file(config.hpp.in "${CMAKE_CURRENT_BINARY_DIR}/oneapi/mkl/config.hpp.configured")
file(GENERATE
diff --git a/src/config.hpp.in b/src/config.hpp.in
index 5d8b9a136..cafda98a8 100644
--- a/src/config.hpp.in
+++ b/src/config.hpp.in
@@ -38,6 +38,7 @@
#cmakedefine ONEMKL_ENABLE_ROCFFT_BACKEND
#cmakedefine ONEMKL_ENABLE_ROCRAND_BACKEND
#cmakedefine ONEMKL_ENABLE_ROCSOLVER_BACKEND
+#cmakedefine ONEMKL_ENABLE_ROCSPARSE_BACKEND
#cmakedefine ONEMKL_BUILD_SHARED_LIBS
#endif
diff --git a/src/sparse_blas/backends/CMakeLists.txt b/src/sparse_blas/backends/CMakeLists.txt
index baae9445d..4ee6b1dc1 100644
--- a/src/sparse_blas/backends/CMakeLists.txt
+++ b/src/sparse_blas/backends/CMakeLists.txt
@@ -31,3 +31,7 @@ endif()
if(ENABLE_CUSPARSE_BACKEND)
add_subdirectory(cusparse)
endif()
+
+if(ENABLE_ROCSPARSE_BACKEND)
+ add_subdirectory(rocsparse)
+endif()
diff --git a/src/sparse_blas/backends/common_launch_task.hpp b/src/sparse_blas/backends/common_launch_task.hpp
new file mode 100644
index 000000000..df245775a
--- /dev/null
+++ b/src/sparse_blas/backends/common_launch_task.hpp
@@ -0,0 +1,414 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_COMMON_LAUNCH_TASK_HPP_
+#define _ONEMKL_SPARSE_BLAS_BACKENDS_COMMON_LAUNCH_TASK_HPP_
+
+/// This file provides helper functions to submit host_tasks using buffers or USM seamlessly
+
+namespace oneapi::mkl::sparse::detail {
+
+template <typename T, typename Container>
+auto get_value_accessor(sycl::handler& cgh, Container container) {
+    auto buffer_ptr =
+        reinterpret_cast<sycl::buffer<T, 1>*>(container->value_container.buffer_ptr.get());
+    return buffer_ptr->template get_access<sycl::access::mode::read_write>(cgh);
+}
+
+template <typename T, typename... Ts>
+auto get_fp_accessors(sycl::handler& cgh, Ts... containers) {
+    return std::array<sycl::accessor<T, 1>, sizeof...(containers)>{ get_value_accessor<T>(
+        cgh, containers)... };
+}
+
+template <typename T>
+auto get_row_accessor(sycl::handler& cgh, matrix_handle_t smhandle) {
+    auto buffer_ptr =
+        reinterpret_cast<sycl::buffer<T, 1>*>(smhandle->row_container.buffer_ptr.get());
+    return buffer_ptr->template get_access<sycl::access::mode::read_write>(cgh);
+}
+
+template <typename T>
+auto get_col_accessor(sycl::handler& cgh, matrix_handle_t smhandle) {
+    auto buffer_ptr =
+        reinterpret_cast<sycl::buffer<T, 1>*>(smhandle->col_container.buffer_ptr.get());
+    return buffer_ptr->template get_access<sycl::access::mode::read_write>(cgh);
+}
+
+template <typename T>
+auto get_int_accessors(sycl::handler& cgh, matrix_handle_t smhandle) {
+    return std::array<sycl::accessor<T, 1>, 2>{ get_row_accessor<T>(cgh, smhandle),
+                                                get_col_accessor<T>(cgh, smhandle) };
+}
+
+template <typename Functor, typename... CaptureOnlyAcc>
+void submit_host_task(sycl::handler& cgh, sycl::queue& queue, Functor functor,
+ CaptureOnlyAcc... capture_only_accessors) {
+ // Only capture the accessors to ensure the dependencies are properly
+    // handled. The accessors' pointers have already been set to the native
+ // container types in previous functions. This assumes the underlying
+ // pointer of the buffer does not change. This is not guaranteed by the SYCL
+ // specification but should be true for all the implementations. This
+ // assumption avoids the overhead of resetting the pointer of all data
+ // handles for each enqueued command.
+ cgh.host_task([functor, queue, capture_only_accessors...](sycl::interop_handle ih) {
+ auto unused = std::make_tuple(capture_only_accessors...);
+ (void)unused;
+ functor(ih);
+ });
+}
+
+template <typename Functor, typename... CaptureOnlyAcc>
+void submit_host_task_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor,
+                               sycl::accessor<std::uint8_t, 1> workspace_acc,
+ CaptureOnlyAcc... capture_only_accessors) {
+ // Only capture the accessors to ensure the dependencies are properly
+    // handled. The accessors' pointers have already been set to the native
+ // container types in previous functions. This assumes the underlying
+ // pointer of the buffer does not change. This is not guaranteed by the SYCL
+ // specification but should be true for all the implementations. This
+ // assumption avoids the overhead of resetting the pointer of all data
+ // handles for each enqueued command.
+ cgh.host_task(
+ [functor, queue, workspace_acc, capture_only_accessors...](sycl::interop_handle ih) {
+ auto unused = std::make_tuple(capture_only_accessors...);
+ (void)unused;
+ functor(ih, workspace_acc);
+ });
+}
+
+template <typename Functor, typename... CaptureOnlyAcc>
+void submit_native_command_ext(sycl::handler& cgh, sycl::queue& queue, Functor functor,
+                               const std::vector<sycl::event>& dependencies,
+ CaptureOnlyAcc... capture_only_accessors) {
+ // Only capture the accessors to ensure the dependencies are properly
+    // handled. The accessors' pointers have already been set to the native
+ // container types in previous functions. This assumes the underlying
+ // pointer of the buffer does not change. This is not guaranteed by the SYCL
+ // specification but should be true for all the implementations. This
+ // assumption avoids the overhead of resetting the pointer of all data
+ // handles for each enqueued command.
+#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
+ cgh.ext_codeplay_enqueue_native_command(
+ [functor, queue, dependencies, capture_only_accessors...](sycl::interop_handle ih) {
+ auto unused = std::make_tuple(capture_only_accessors...);
+ (void)unused;
+            // The functor using ext_codeplay_enqueue_native_command needs to
+ // explicitly wait on the events for the SPARSE domain. The
+ // extension ext_codeplay_enqueue_native_command is used to launch
+ // the compute operation which depends on the previous optimize
+ // step. In cuSPARSE the optimize step is synchronous but it is
+ // asynchronous in oneMKL Interface. The optimize step may not use
+ // the CUDA stream which would make it impossible for
+ // ext_codeplay_enqueue_native_command to automatically ensure it
+ // has completed before the compute function starts. These waits are
+ // used to ensure the optimize step has completed before starting
+ // the computation.
+ for (auto event : dependencies) {
+ event.wait();
+ }
+ functor(ih);
+ });
+#else
+ (void)dependencies;
+ submit_host_task(cgh, queue, functor, capture_only_accessors...);
+#endif
+}
+
+template <typename Functor, typename... CaptureOnlyAcc>
+void submit_native_command_ext_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor,
+                                        const std::vector<sycl::event>& dependencies,
+                                        sycl::accessor<std::uint8_t, 1> workspace_acc,
+ CaptureOnlyAcc... capture_only_accessors) {
+ // Only capture the accessors to ensure the dependencies are properly
+    // handled. The accessors' pointers have already been set to the native
+ // container types in previous functions. This assumes the underlying
+ // pointer of the buffer does not change. This is not guaranteed by the SYCL
+ // specification but should be true for all the implementations. This
+ // assumption avoids the overhead of resetting the pointer of all data
+ // handles for each enqueued command.
+#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
+ cgh.ext_codeplay_enqueue_native_command([functor, queue, dependencies, workspace_acc,
+ capture_only_accessors...](sycl::interop_handle ih) {
+ auto unused = std::make_tuple(capture_only_accessors...);
+ (void)unused;
+        // The functor using ext_codeplay_enqueue_native_command needs to
+ // explicitly wait on the events for the SPARSE domain. The
+ // extension ext_codeplay_enqueue_native_command is used to launch
+ // the compute operation which depends on the previous optimize
+ // step. In cuSPARSE the optimize step is synchronous but it is
+ // asynchronous in oneMKL Interface. The optimize step may not use
+ // the CUDA stream which would make it impossible for
+ // ext_codeplay_enqueue_native_command to automatically ensure it
+ // has completed before the compute function starts. These waits are
+ // used to ensure the optimize step has completed before starting
+ // the computation.
+ for (auto event : dependencies) {
+ event.wait();
+ }
+ functor(ih, workspace_acc);
+ });
+#else
+ (void)dependencies;
+ submit_host_task_with_acc(cgh, queue, functor, workspace_acc, capture_only_accessors...);
+#endif
+}
+
+/// Helper submit functions to capture all accessors from the generic containers
+/// \p other_containers and ensure the dependencies of buffers are respected.
+/// The accessors are not directly used as the underlying data pointer has
+/// already been captured in previous functions.
+/// \p workspace_buffer is an optional buffer. Its accessor will be given to the
+/// functor as a last argument if \p UseWorkspace is true.
+/// \p UseWorkspace must be true to use the given \p workspace_buffer.
+/// \p UseEnqueueNativeCommandExt controls whether host_task are used or the
+/// extension ext_codeplay_enqueue_native_command is used to launch tasks. The
+/// extension should only be used for asynchronous functions using native
+/// backend's functions. The extension can only be used for in-order queues as
+/// the same cuStream needs to be used for the 3 steps to run an operation:
+/// querying the buffer size, optimizing and running the computation. This means
+/// a different cuStream can be used inside the native_command than the native
+/// cuStream used by the extension.
+template <bool UseWorkspace, bool UseEnqueueNativeCommandExt, typename Functor, typename... Ts>
+sycl::event dispatch_submit_impl_fp_int(const std::string& function_name, sycl::queue queue,
+                                        const std::vector<sycl::event>& dependencies,
+                                        Functor functor, matrix_handle_t sm_handle,
+                                        sycl::buffer<std::uint8_t> workspace_buffer,
+ Ts... other_containers) {
+ bool is_in_order_queue = queue.is_in_order();
+ if (sm_handle->all_use_buffer()) {
+ data_type value_type = sm_handle->get_value_type();
+ data_type int_type = sm_handle->get_int_type();
+
+#define ONEMKL_SUBMIT(FP_TYPE, INT_TYPE) \
+ return queue.submit([&](sycl::handler& cgh) { \
+ cgh.depends_on(dependencies); \
+        auto fp_accs = get_fp_accessors<FP_TYPE>(cgh, sm_handle, other_containers...);        \
+        auto int_accs = get_int_accessors<INT_TYPE>(cgh, sm_handle);                          \
+        auto workspace_acc = workspace_buffer.get_access<sycl::access::mode::read_write>(cgh); \
+ if constexpr (UseWorkspace) { \
+ if constexpr (UseEnqueueNativeCommandExt) { \
+ if (is_in_order_queue) { \
+ submit_native_command_ext_with_acc(cgh, queue, functor, dependencies, \
+ workspace_acc, fp_accs, int_accs); \
+ } \
+ else { \
+ submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, \
+ int_accs); \
+ } \
+ } \
+ else { \
+ submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, int_accs); \
+ } \
+ } \
+ else { \
+ (void)workspace_buffer; \
+ if constexpr (UseEnqueueNativeCommandExt) { \
+ if (is_in_order_queue) { \
+ submit_native_command_ext(cgh, queue, functor, dependencies, fp_accs, \
+ int_accs); \
+ } \
+ else { \
+ submit_host_task(cgh, queue, functor, fp_accs, int_accs); \
+ } \
+ } \
+ else { \
+ submit_host_task(cgh, queue, functor, fp_accs, int_accs); \
+ } \
+ } \
+ })
+#define ONEMKL_SUBMIT_INT(FP_TYPE) \
+ if (int_type == data_type::int32) { \
+ ONEMKL_SUBMIT(FP_TYPE, std::int32_t); \
+ } \
+ else if (int_type == data_type::int64) { \
+ ONEMKL_SUBMIT(FP_TYPE, std::int64_t); \
+ }
+
+ if (value_type == data_type::real_fp32) {
+ ONEMKL_SUBMIT_INT(float)
+ }
+ else if (value_type == data_type::real_fp64) {
+ ONEMKL_SUBMIT_INT(double)
+ }
+ else if (value_type == data_type::complex_fp32) {
+        ONEMKL_SUBMIT_INT(std::complex<float>)
+ }
+ else if (value_type == data_type::complex_fp64) {
+        ONEMKL_SUBMIT_INT(std::complex<double>)
+ }
+
+#undef ONEMKL_SUBMIT_INT
+#undef ONEMKL_SUBMIT
+
+ throw oneapi::mkl::exception("sparse_blas", function_name,
+ "Could not dispatch buffer kernel to a supported type");
+ }
+ else {
+ // USM submit does not need to capture accessors
+ if constexpr (!UseWorkspace) {
+ return queue.submit([&](sycl::handler& cgh) {
+ cgh.depends_on(dependencies);
+ if constexpr (UseEnqueueNativeCommandExt) {
+ if (is_in_order_queue) {
+ submit_native_command_ext(cgh, queue, functor, dependencies);
+ }
+ else {
+ submit_host_task(cgh, queue, functor);
+ }
+ }
+ else {
+ submit_host_task(cgh, queue, functor);
+ }
+ });
+ }
+ else {
+ throw oneapi::mkl::exception("sparse_blas", function_name,
+ "Internal error: Cannot use accessor workspace with USM");
+ }
+ }
+}
+
+/// Similar to dispatch_submit_impl_fp_int but only dispatches the host_task based on the floating point value type.
+template <typename Functor, typename ContainerT>
+sycl::event dispatch_submit_impl_fp(const std::string& function_name, sycl::queue queue,
+                                    const std::vector<sycl::event>& dependencies, Functor functor,
+ ContainerT container_handle) {
+ if (container_handle->all_use_buffer()) {
+ data_type value_type = container_handle->get_value_type();
+
+#define ONEMKL_SUBMIT(FP_TYPE) \
+ return queue.submit([&](sycl::handler& cgh) { \
+ cgh.depends_on(dependencies); \
+        auto fp_accs = get_fp_accessors<FP_TYPE>(cgh, container_handle);                     \
+ submit_host_task(cgh, queue, functor, fp_accs); \
+ })
+
+ if (value_type == data_type::real_fp32) {
+ ONEMKL_SUBMIT(float);
+ }
+ else if (value_type == data_type::real_fp64) {
+ ONEMKL_SUBMIT(double);
+ }
+ else if (value_type == data_type::complex_fp32) {
+        ONEMKL_SUBMIT(std::complex<float>);
+ }
+ else if (value_type == data_type::complex_fp64) {
+        ONEMKL_SUBMIT(std::complex<double>);
+ }
+
+#undef ONEMKL_SUBMIT
+
+ throw oneapi::mkl::exception("sparse_blas", function_name,
+ "Could not dispatch buffer kernel to a supported type");
+ }
+ else {
+ return queue.submit([&](sycl::handler& cgh) {
+ cgh.depends_on(dependencies);
+ submit_host_task(cgh, queue, functor);
+ });
+ }
+}
+
+/// Helper function for dispatch_submit_impl_fp_int
+template <typename Functor, typename... Ts>
+sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor,
+                            matrix_handle_t sm_handle, sycl::buffer<std::uint8_t> workspace_buffer,
+                            Ts... other_containers) {
+    constexpr bool UseWorkspace = true;
+    constexpr bool UseEnqueueNativeCommandExt = false;
+    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
+        function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...);
+}
+
+/// Helper function for dispatch_submit_impl_fp_int
+template <typename Functor, typename... Ts>
+sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue,
+                            const std::vector<sycl::event>& dependencies, Functor functor,
+                            matrix_handle_t sm_handle, Ts... other_containers) {
+    constexpr bool UseWorkspace = false;
+    constexpr bool UseEnqueueNativeCommandExt = false;
+    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
+    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
+        function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...);
+}
+
+/// Helper function for dispatch_submit_impl_fp_int
+template <typename Functor, typename... Ts>
+sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor,
+                            matrix_handle_t sm_handle, Ts... other_containers) {
+    constexpr bool UseWorkspace = false;
+    constexpr bool UseEnqueueNativeCommandExt = false;
+    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
+    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
+        function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...);
+}
+
+/// Helper function for dispatch_submit_impl_fp_int
+template <typename Functor, typename... Ts>
+sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue,
+                                       Functor functor, matrix_handle_t sm_handle,
+                                       sycl::buffer<std::uint8_t> workspace_buffer,
+                                       Ts... other_containers) {
+    constexpr bool UseWorkspace = true;
+#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
+    constexpr bool UseEnqueueNativeCommandExt = true;
+#else
+    constexpr bool UseEnqueueNativeCommandExt = false;
+#endif
+    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
+        function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...);
+}
+
+/// Helper function for dispatch_submit_impl_fp_int
+template <typename Functor, typename... Ts>
+sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue,
+                                       const std::vector<sycl::event>& dependencies,
+                                       Functor functor, matrix_handle_t sm_handle,
+                                       Ts... other_containers) {
+    constexpr bool UseWorkspace = false;
+#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
+    constexpr bool UseEnqueueNativeCommandExt = true;
+#else
+    constexpr bool UseEnqueueNativeCommandExt = false;
+#endif
+    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
+    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
+        function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...);
+}
+
+/// Helper function for dispatch_submit_impl_fp_int
+template <typename Functor, typename... Ts>
+sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue,
+                                       Functor functor, matrix_handle_t sm_handle,
+                                       Ts... other_containers) {
+    constexpr bool UseWorkspace = false;
+#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
+    constexpr bool UseEnqueueNativeCommandExt = true;
+#else
+    constexpr bool UseEnqueueNativeCommandExt = false;
+#endif
+    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
+    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
+        function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...);
+}
+
+} // namespace oneapi::mkl::sparse::detail
+
+#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_COMMON_LAUNCH_TASK_HPP_
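For context, an illustrative sketch (not part of the patch) of how a backend compute function drives these helpers; the handle types match this file's expectations, while the functor body and the "spmv_example" name are placeholders:

```cpp
// Hypothetical caller living in a backend namespace such as
// oneapi::mkl::sparse::rocsparse. A real implementation would run a native
// call (e.g. rocsparse_spmv) on the stream cached in its descriptor.
sycl::event spmv_example(sycl::queue& queue, const std::vector<sycl::event>& dependencies,
                         matrix_handle_t A_handle, dense_vector_handle_t x_handle,
                         dense_vector_handle_t y_handle) {
    auto functor = [=](sycl::interop_handle) {
        // Native compute call goes here. When the native-command extension is
        // active, the helper has already waited on `dependencies`, so the
        // asynchronous optimize step is guaranteed to have completed.
    };
    // Captures the value/index accessors of all three handles so buffer
    // dependencies are honored; the no-workspace overload is used.
    return detail::dispatch_submit_native_ext("spmv_example", queue, dependencies, functor,
                                              A_handle, x_handle, y_handle);
}
```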
diff --git a/src/sparse_blas/backends/cusparse/cusparse_handles.cpp b/src/sparse_blas/backends/cusparse/cusparse_handles.cpp
index ff3d8fcae..a6c803248 100644
--- a/src/sparse_blas/backends/cusparse/cusparse_handles.cpp
+++ b/src/sparse_blas/backends/cusparse/cusparse_handles.cpp
@@ -22,6 +22,7 @@
#include "cusparse_error.hpp"
#include "cusparse_helper.hpp"
#include "cusparse_handles.hpp"
+#include "cusparse_scope_handle.hpp"
#include "cusparse_task.hpp"
#include "sparse_blas/macros.hpp"
diff --git a/src/sparse_blas/backends/cusparse/cusparse_task.hpp b/src/sparse_blas/backends/cusparse/cusparse_task.hpp
index 0d86d642d..edfb39064 100644
--- a/src/sparse_blas/backends/cusparse/cusparse_task.hpp
+++ b/src/sparse_blas/backends/cusparse/cusparse_task.hpp
@@ -17,401 +17,14 @@
*
**************************************************************************/
-#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_
-#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_
+#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASK_HPP_
+#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASK_HPP_
-#include "cusparse_handles.hpp"
-#include "cusparse_scope_handle.hpp"
-
-/// This file provide a helper function to submit host_task using buffers or USM seamlessly
+#include "cusparse_error.hpp"
+#include "sparse_blas/backends/common_launch_task.hpp"
namespace oneapi::mkl::sparse::cusparse::detail {
-template <typename T, typename Container>
-auto get_value_accessor(sycl::handler& cgh, Container container) {
-    auto buffer_ptr =
-        reinterpret_cast<sycl::buffer<T, 1>*>(container->value_container.buffer_ptr.get());
-    return buffer_ptr->template get_access<sycl::access::mode::read_write>(cgh);
-}
-
-template <typename T, typename... Ts>
-auto get_fp_accessors(sycl::handler& cgh, Ts... containers) {
-    return std::array<sycl::accessor<T, 1>, sizeof...(containers)>{ get_value_accessor<T>(
-        cgh, containers)... };
-}
-
-template <typename T>
-auto get_row_accessor(sycl::handler& cgh, matrix_handle_t smhandle) {
-    auto buffer_ptr =
-        reinterpret_cast<sycl::buffer<T, 1>*>(smhandle->row_container.buffer_ptr.get());
-    return buffer_ptr->template get_access<sycl::access::mode::read_write>(cgh);
-}
-
-template <typename T>
-auto get_col_accessor(sycl::handler& cgh, matrix_handle_t smhandle) {
-    auto buffer_ptr =
-        reinterpret_cast<sycl::buffer<T, 1>*>(smhandle->col_container.buffer_ptr.get());
-    return buffer_ptr->template get_access<sycl::access::mode::read_write>(cgh);
-}
-
-template <typename T>
-auto get_int_accessors(sycl::handler& cgh, matrix_handle_t smhandle) {
-    return std::array<sycl::accessor<T, 1>, 2>{ get_row_accessor<T>(cgh, smhandle),
-                                                get_col_accessor<T>(cgh, smhandle) };
-}
-
-template <typename Functor, typename... CaptureOnlyAcc>
-void submit_host_task(sycl::handler& cgh, sycl::queue& queue, Functor functor,
- CaptureOnlyAcc... capture_only_accessors) {
- // Only capture the accessors to ensure the dependencies are properly
- // handled. The accessors's pointer have already been set to the native
- // container types in previous functions. This assumes the underlying
- // pointer of the buffer does not change. This is not guaranteed by the SYCL
- // specification but should be true for all the implementations. This
- // assumption avoids the overhead of resetting the pointer of all data
- // handles for each enqueued command.
- cgh.host_task([functor, queue, capture_only_accessors...](sycl::interop_handle ih) {
- auto unused = std::make_tuple(capture_only_accessors...);
- (void)unused;
- functor(ih);
- });
-}
-
-template <typename Functor, typename... CaptureOnlyAcc>
-void submit_host_task_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor,
-                               sycl::accessor<std::uint8_t, 1> workspace_acc,
- CaptureOnlyAcc... capture_only_accessors) {
- // Only capture the accessors to ensure the dependencies are properly
- // handled. The accessors's pointer have already been set to the native
- // container types in previous functions. This assumes the underlying
- // pointer of the buffer does not change. This is not guaranteed by the SYCL
- // specification but should be true for all the implementations. This
- // assumption avoids the overhead of resetting the pointer of all data
- // handles for each enqueued command.
- cgh.host_task(
- [functor, queue, workspace_acc, capture_only_accessors...](sycl::interop_handle ih) {
- auto unused = std::make_tuple(capture_only_accessors...);
- (void)unused;
- functor(ih, workspace_acc);
- });
-}
-
-template <typename Functor, typename... CaptureOnlyAcc>
-void submit_native_command_ext(sycl::handler& cgh, sycl::queue& queue, Functor functor,
-                               const std::vector<sycl::event>& dependencies,
- CaptureOnlyAcc... capture_only_accessors) {
- // Only capture the accessors to ensure the dependencies are properly
- // handled. The accessors's pointer have already been set to the native
- // container types in previous functions. This assumes the underlying
- // pointer of the buffer does not change. This is not guaranteed by the SYCL
- // specification but should be true for all the implementations. This
- // assumption avoids the overhead of resetting the pointer of all data
- // handles for each enqueued command.
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
- cgh.ext_codeplay_enqueue_native_command(
- [functor, queue, dependencies, capture_only_accessors...](sycl::interop_handle ih) {
- auto unused = std::make_tuple(capture_only_accessors...);
- (void)unused;
- // The functor using ext_codeplay_enqueue_native_command need to
- // explicitly wait on the events for the SPARSE domain. The
- // extension ext_codeplay_enqueue_native_command is used to launch
- // the compute operation which depends on the previous optimize
- // step. In cuSPARSE the optimize step is synchronous but it is
- // asynchronous in oneMKL Interface. The optimize step may not use
- // the CUDA stream which would make it impossible for
- // ext_codeplay_enqueue_native_command to automatically ensure it
- // has completed before the compute function starts. These waits are
- // used to ensure the optimize step has completed before starting
- // the computation.
- for (auto event : dependencies) {
- event.wait();
- }
- functor(ih);
- });
-#else
- (void)dependencies;
- submit_host_task(cgh, queue, functor, capture_only_accessors...);
-#endif
-}
-
-template <typename Functor, typename... CaptureOnlyAcc>
-void submit_native_command_ext_with_acc(sycl::handler& cgh, sycl::queue& queue, Functor functor,
-                                        const std::vector<sycl::event>& dependencies,
-                                        sycl::accessor<std::uint8_t, 1> workspace_acc,
- CaptureOnlyAcc... capture_only_accessors) {
- // Only capture the accessors to ensure the dependencies are properly
- // handled. The accessors's pointer have already been set to the native
- // container types in previous functions. This assumes the underlying
- // pointer of the buffer does not change. This is not guaranteed by the SYCL
- // specification but should be true for all the implementations. This
- // assumption avoids the overhead of resetting the pointer of all data
- // handles for each enqueued command.
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
- cgh.ext_codeplay_enqueue_native_command([functor, queue, dependencies, workspace_acc,
- capture_only_accessors...](sycl::interop_handle ih) {
- auto unused = std::make_tuple(capture_only_accessors...);
- (void)unused;
- // The functor using ext_codeplay_enqueue_native_command need to
- // explicitly wait on the events for the SPARSE domain. The
- // extension ext_codeplay_enqueue_native_command is used to launch
- // the compute operation which depends on the previous optimize
- // step. In cuSPARSE the optimize step is synchronous but it is
- // asynchronous in oneMKL Interface. The optimize step may not use
- // the CUDA stream which would make it impossible for
- // ext_codeplay_enqueue_native_command to automatically ensure it
- // has completed before the compute function starts. These waits are
- // used to ensure the optimize step has completed before starting
- // the computation.
- for (auto event : dependencies) {
- event.wait();
- }
- functor(ih, workspace_acc);
- });
-#else
- (void)dependencies;
- submit_host_task_with_acc(cgh, queue, functor, workspace_acc, capture_only_accessors...);
-#endif
-}
-
-/// Helper submit functions to capture all accessors from the generic containers
-/// \p other_containers and ensure the dependencies of buffers are respected.
-/// The accessors are not directly used as the underlying data pointer has
-/// already been captured in previous functions.
-/// \p workspace_buffer is an optional buffer. Its accessor will be given to the
-/// functor as a last argument if \p UseWorkspace is true.
-/// \p UseWorkspace must be true to use the given \p workspace_buffer.
-/// \p UseEnqueueNativeCommandExt controls whether host_task are used or the
-/// extension ext_codeplay_enqueue_native_command is used to launch tasks. The
-/// extension should only be used for asynchronous functions using native
-/// backend's functions. The extension can only be used for in-order queues as
-/// the same cuStream needs to be used for the 3 steps to run an operation:
-/// querying the buffer size, optimizing and running the computation. This means
-/// a different cuStream can be used inside the native_command than the native
-/// cuStream used by the extension.
-template <bool UseWorkspace, bool UseEnqueueNativeCommandExt, typename Functor, typename... Ts>
-sycl::event dispatch_submit_impl_fp_int(const std::string& function_name, sycl::queue queue,
-                                        const std::vector<sycl::event>& dependencies,
-                                        Functor functor, matrix_handle_t sm_handle,
-                                        sycl::buffer<std::uint8_t> workspace_buffer,
- Ts... other_containers) {
- bool is_in_order_queue = queue.is_in_order();
- if (sm_handle->all_use_buffer()) {
- data_type value_type = sm_handle->get_value_type();
- data_type int_type = sm_handle->get_int_type();
-
-#define ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, INT_TYPE) \
- return queue.submit([&](sycl::handler& cgh) { \
- cgh.depends_on(dependencies); \
-        auto fp_accs = get_fp_accessors<FP_TYPE>(cgh, sm_handle, other_containers...);        \
-        auto int_accs = get_int_accessors<INT_TYPE>(cgh, sm_handle);                          \
-        auto workspace_acc = workspace_buffer.get_access<sycl::access::mode::read_write>(cgh); \
- if constexpr (UseWorkspace) { \
- if constexpr (UseEnqueueNativeCommandExt) { \
- if (is_in_order_queue) { \
- submit_native_command_ext_with_acc(cgh, queue, functor, dependencies, \
- workspace_acc, fp_accs, int_accs); \
- } \
- else { \
- submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, \
- int_accs); \
- } \
- } \
- else { \
- submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, int_accs); \
- } \
- } \
- else { \
- (void)workspace_buffer; \
- if constexpr (UseEnqueueNativeCommandExt) { \
- if (is_in_order_queue) { \
- submit_native_command_ext(cgh, queue, functor, dependencies, fp_accs, \
- int_accs); \
- } \
- else { \
- submit_host_task(cgh, queue, functor, fp_accs, int_accs); \
- } \
- } \
- else { \
- submit_host_task(cgh, queue, functor, fp_accs, int_accs); \
- } \
- } \
- })
-#define ONEMKL_CUSPARSE_SUBMIT_INT(FP_TYPE) \
- if (int_type == data_type::int32) { \
- ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, std::int32_t); \
- } \
- else if (int_type == data_type::int64) { \
- ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, std::int64_t); \
- }
-
- if (value_type == data_type::real_fp32) {
- ONEMKL_CUSPARSE_SUBMIT_INT(float)
- }
- else if (value_type == data_type::real_fp64) {
- ONEMKL_CUSPARSE_SUBMIT_INT(double)
- }
- else if (value_type == data_type::complex_fp32) {
-        ONEMKL_CUSPARSE_SUBMIT_INT(std::complex<float>)
- }
- else if (value_type == data_type::complex_fp64) {
-        ONEMKL_CUSPARSE_SUBMIT_INT(std::complex<double>)
- }
-
-#undef ONEMKL_CUSPARSE_SUBMIT_INT
-#undef ONEMKL_CUSPARSE_SUBMIT
-
- throw oneapi::mkl::exception("sparse_blas", function_name,
- "Could not dispatch buffer kernel to a supported type");
- }
- else {
- // USM submit does not need to capture accessors
- if constexpr (!UseWorkspace) {
- return queue.submit([&](sycl::handler& cgh) {
- cgh.depends_on(dependencies);
- if constexpr (UseEnqueueNativeCommandExt) {
- if (is_in_order_queue) {
- submit_native_command_ext(cgh, queue, functor, dependencies);
- }
- else {
- submit_host_task(cgh, queue, functor);
- }
- }
- else {
- submit_host_task(cgh, queue, functor);
- }
- });
- }
- else {
- throw oneapi::mkl::exception("sparse_blas", function_name,
- "Internal error: Cannot use accessor workspace with USM");
- }
- }
-}
-
-/// Similar to dispatch_submit_impl_fp_int but only dispatches the host_task based on the floating point value type.
-template <typename Functor, typename ContainerT>
-sycl::event dispatch_submit_impl_fp(const std::string& function_name, sycl::queue queue,
-                                    const std::vector<sycl::event>& dependencies, Functor functor,
- ContainerT container_handle) {
- if (container_handle->all_use_buffer()) {
- data_type value_type = container_handle->get_value_type();
-
-#define ONEMKL_CUSPARSE_SUBMIT(FP_TYPE) \
- return queue.submit([&](sycl::handler& cgh) { \
- cgh.depends_on(dependencies); \
-        auto fp_accs = get_fp_accessors<FP_TYPE>(cgh, container_handle);                     \
- submit_host_task(cgh, queue, functor, fp_accs); \
- })
-
- if (value_type == data_type::real_fp32) {
- ONEMKL_CUSPARSE_SUBMIT(float);
- }
- else if (value_type == data_type::real_fp64) {
- ONEMKL_CUSPARSE_SUBMIT(double);
- }
- else if (value_type == data_type::complex_fp32) {
-        ONEMKL_CUSPARSE_SUBMIT(std::complex<float>);
- }
- else if (value_type == data_type::complex_fp64) {
-        ONEMKL_CUSPARSE_SUBMIT(std::complex<double>);
- }
-
-#undef ONEMKL_CUSPARSE_SUBMIT
-
- throw oneapi::mkl::exception("sparse_blas", function_name,
- "Could not dispatch buffer kernel to a supported type");
- }
- else {
- return queue.submit([&](sycl::handler& cgh) {
- cgh.depends_on(dependencies);
- submit_host_task(cgh, queue, functor);
- });
- }
-}
-
-/// Helper function for dispatch_submit_impl_fp_int
-template <typename Functor, typename... Ts>
-sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor,
-                            matrix_handle_t sm_handle, sycl::buffer<std::uint8_t> workspace_buffer,
-                            Ts... other_containers) {
-    constexpr bool UseWorkspace = true;
-    constexpr bool UseEnqueueNativeCommandExt = false;
-    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
-        function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...);
-}
-
-/// Helper function for dispatch_submit_impl_fp_int
-template <typename Functor, typename... Ts>
-sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue,
-                            const std::vector<sycl::event>& dependencies, Functor functor,
-                            matrix_handle_t sm_handle, Ts... other_containers) {
-    constexpr bool UseWorkspace = false;
-    constexpr bool UseEnqueueNativeCommandExt = false;
-    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
-    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
-        function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...);
-}
-
-/// Helper function for dispatch_submit_impl_fp_int
-template <typename Functor, typename... Ts>
-sycl::event dispatch_submit(const std::string& function_name, sycl::queue queue, Functor functor,
-                            matrix_handle_t sm_handle, Ts... other_containers) {
-    constexpr bool UseWorkspace = false;
-    constexpr bool UseEnqueueNativeCommandExt = false;
-    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
-    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
-        function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...);
-}
-
-/// Helper function for dispatch_submit_impl_fp_int
-template <typename Functor, typename... Ts>
-sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue,
-                                       Functor functor, matrix_handle_t sm_handle,
-                                       sycl::buffer<std::uint8_t> workspace_buffer,
-                                       Ts... other_containers) {
-    constexpr bool UseWorkspace = true;
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
-    constexpr bool UseEnqueueNativeCommandExt = true;
-#else
-    constexpr bool UseEnqueueNativeCommandExt = false;
-#endif
-    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
-        function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...);
-}
-
-/// Helper function for dispatch_submit_impl_fp_int
-template <typename Functor, typename... Ts>
-sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue,
-                                       const std::vector<sycl::event>& dependencies,
-                                       Functor functor, matrix_handle_t sm_handle,
-                                       Ts... other_containers) {
-    constexpr bool UseWorkspace = false;
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
-    constexpr bool UseEnqueueNativeCommandExt = true;
-#else
-    constexpr bool UseEnqueueNativeCommandExt = false;
-#endif
-    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
-    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
-        function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...);
-}
-
-/// Helper function for dispatch_submit_impl_fp_int
-template <typename Functor, typename... Ts>
-sycl::event dispatch_submit_native_ext(const std::string& function_name, sycl::queue queue,
-                                       Functor functor, matrix_handle_t sm_handle,
-                                       Ts... other_containers) {
-    constexpr bool UseWorkspace = false;
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
-    constexpr bool UseEnqueueNativeCommandExt = true;
-#else
-    constexpr bool UseEnqueueNativeCommandExt = false;
-#endif
-    sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
-    return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
-        function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...);
-}
-
// Helper function for functors submitted to host_task or native_command.
// When the extension is disabled, host_task are used and the synchronization is needed to ensure the sycl::event corresponds to the end of the whole functor.
// When the extension is enabled, host_task are still used for out-of-order queues, see description of dispatch_submit_impl_fp_int.
@@ -428,4 +41,4 @@ inline void synchronize_if_needed(bool is_in_order_queue, CUstream cu_stream) {
} // namespace oneapi::mkl::sparse::cusparse::detail
-#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_
+#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASK_HPP_
diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp
index 5fd24d3f4..5d1cd0290 100644
--- a/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp
+++ b/src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp
@@ -20,9 +20,10 @@
#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp"
#include "sparse_blas/backends/cusparse/cusparse_error.hpp"
+#include "sparse_blas/backends/cusparse/cusparse_handles.hpp"
#include "sparse_blas/backends/cusparse/cusparse_helper.hpp"
#include "sparse_blas/backends/cusparse/cusparse_task.hpp"
-#include "sparse_blas/backends/cusparse/cusparse_handles.hpp"
+#include "sparse_blas/backends/cusparse/cusparse_scope_handle.hpp"
#include "sparse_blas/common_op_verification.hpp"
#include "sparse_blas/macros.hpp"
#include "sparse_blas/matrix_view_comparison.hpp"
diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp
index 03b848916..44b6f5a01 100644
--- a/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp
+++ b/src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp
@@ -20,9 +20,10 @@
#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp"
#include "sparse_blas/backends/cusparse/cusparse_error.hpp"
+#include "sparse_blas/backends/cusparse/cusparse_handles.hpp"
#include "sparse_blas/backends/cusparse/cusparse_helper.hpp"
#include "sparse_blas/backends/cusparse/cusparse_task.hpp"
-#include "sparse_blas/backends/cusparse/cusparse_handles.hpp"
+#include "sparse_blas/backends/cusparse/cusparse_scope_handle.hpp"
#include "sparse_blas/common_op_verification.hpp"
#include "sparse_blas/macros.hpp"
#include "sparse_blas/matrix_view_comparison.hpp"
diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp
index 5c49df013..e965a4dcb 100644
--- a/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp
+++ b/src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp
@@ -20,9 +20,10 @@
#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp"
#include "sparse_blas/backends/cusparse/cusparse_error.hpp"
+#include "sparse_blas/backends/cusparse/cusparse_handles.hpp"
#include "sparse_blas/backends/cusparse/cusparse_helper.hpp"
#include "sparse_blas/backends/cusparse/cusparse_task.hpp"
-#include "sparse_blas/backends/cusparse/cusparse_handles.hpp"
+#include "sparse_blas/backends/cusparse/cusparse_scope_handle.hpp"
#include "sparse_blas/common_op_verification.hpp"
#include "sparse_blas/macros.hpp"
#include "sparse_blas/matrix_view_comparison.hpp"
diff --git a/src/sparse_blas/backends/rocsparse/CMakeLists.txt b/src/sparse_blas/backends/rocsparse/CMakeLists.txt
new file mode 100644
index 000000000..af26b50eb
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/CMakeLists.txt
@@ -0,0 +1,81 @@
+#===============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+#
+#
+# SPDX-License-Identifier: Apache-2.0
+#===============================================================================
+
+set(LIB_NAME onemkl_sparse_blas_rocsparse)
+set(LIB_OBJ ${LIB_NAME}_obj)
+
+include(WarningsUtils)
+
+add_library(${LIB_NAME})
+add_library(${LIB_OBJ} OBJECT
+ rocsparse_handles.cpp
+ rocsparse_scope_handle.cpp
+ operations/rocsparse_spmm.cpp
+ operations/rocsparse_spmv.cpp
+ operations/rocsparse_spsv.cpp
+ $<$: rocsparse_wrappers.cpp>
+)
+add_dependencies(onemkl_backend_libs_sparse_blas ${LIB_NAME})
+
+target_include_directories(${LIB_OBJ}
+ PRIVATE ${PROJECT_SOURCE_DIR}/include
+ ${PROJECT_SOURCE_DIR}/src
+ ${CMAKE_BINARY_DIR}/bin
+ ${ONEMKL_GENERATED_INCLUDE_PATH}
+)
+
+target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT})
+
+find_package(HIP REQUIRED)
+find_package(rocsparse REQUIRED)
+
+target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocsparse)
+
+target_link_libraries(${LIB_OBJ}
+ PUBLIC ONEMKL::SYCL::SYCL
+ PRIVATE onemkl_warnings
+)
+
+set_target_properties(${LIB_OBJ} PROPERTIES
+ POSITION_INDEPENDENT_CODE ON
+)
+target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ})
+
+#Set oneMKL libraries as not transitive for dynamic
+if(BUILD_SHARED_LIBS)
+ set_target_properties(${LIB_NAME} PROPERTIES
+ INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL
+ )
+endif()
+
+# Add major version to the library
+set_target_properties(${LIB_NAME} PROPERTIES
+ SOVERSION ${PROJECT_VERSION_MAJOR}
+)
+
+# Add dependencies rpath to the library
+list(APPEND CMAKE_BUILD_RPATH $<TARGET_FILE_DIR:${LIB_NAME}>)
+
+# Add the library to install package
+install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets)
+install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets
+ RUNTIME DESTINATION bin
+ ARCHIVE DESTINATION lib
+ LIBRARY DESTINATION lib
+)
diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp
new file mode 100644
index 000000000..2a336d22d
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp
@@ -0,0 +1,350 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp"
+
+#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp"
+#include "sparse_blas/common_op_verification.hpp"
+#include "sparse_blas/macros.hpp"
+#include "sparse_blas/matrix_view_comparison.hpp"
+#include "sparse_blas/sycl_helper.hpp"
+
+namespace oneapi::mkl::sparse {
+
+// Complete the definition of the incomplete type
+struct spmm_descr {
+ // Cache the hipStream_t and global handle to avoid relying on RocsparseScopedContextHandler to retrieve them.
+ hipStream_t hip_stream;
+ rocsparse_handle roc_handle;
+
+ detail::generic_container workspace;
+ std::size_t temp_buffer_size = 0;
+ bool buffer_size_called = false;
+ bool optimized_called = false;
+ oneapi::mkl::transpose last_optimized_opA;
+ oneapi::mkl::transpose last_optimized_opB;
+ oneapi::mkl::sparse::matrix_view last_optimized_A_view;
+ oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle;
+ oneapi::mkl::sparse::dense_matrix_handle_t last_optimized_B_handle;
+ oneapi::mkl::sparse::dense_matrix_handle_t last_optimized_C_handle;
+ oneapi::mkl::sparse::spmm_alg last_optimized_alg;
+};
+
+} // namespace oneapi::mkl::sparse
+
+namespace oneapi::mkl::sparse::rocsparse {
+
+namespace detail {
+
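+// Map the oneMKL spmm algorithm enum to the rocSPARSE SpMM algorithm.
+// Values without a direct rocSPARSE counterpart, such as spmm_alg::default_alg
+// and spmm_alg::no_optimize_alg, fall back to rocsparse_spmm_alg_default.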
+inline auto get_roc_spmm_alg(spmm_alg alg) {
+ switch (alg) {
+ case spmm_alg::coo_alg1: return rocsparse_spmm_alg_coo_segmented;
+ case spmm_alg::coo_alg2: return rocsparse_spmm_alg_coo_atomic;
+ case spmm_alg::coo_alg3: return rocsparse_spmm_alg_coo_segmented_atomic;
+ case spmm_alg::csr_alg1: return rocsparse_spmm_alg_csr;
+ case spmm_alg::csr_alg2: return rocsparse_spmm_alg_csr_row_split;
+ case spmm_alg::csr_alg3: return rocsparse_spmm_alg_csr_merge;
+ default: return rocsparse_spmm_alg_default;
+ }
+}
+
+void check_valid_spmm(const std::string& function_name, matrix_view A_view,
+ matrix_handle_t A_handle, dense_matrix_handle_t B_handle,
+ dense_matrix_handle_t C_handle, bool is_alpha_host_accessible,
+ bool is_beta_host_accessible) {
+ check_valid_spmm_common(function_name, A_view, A_handle, B_handle, C_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+ A_handle->check_valid_handle(function_name);
+}
+
+inline void common_spmm_optimize(
+ oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, bool is_alpha_host_accessible,
+ oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_matrix_handle_t B_handle, bool is_beta_host_accessible,
+ oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg alg,
+ oneapi::mkl::sparse::spmm_descr_t spmm_descr) {
+ check_valid_spmm("spmm_optimize", A_view, A_handle, B_handle, C_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+ if (!spmm_descr->buffer_size_called) {
+ throw mkl::uninitialized("sparse_blas", "spmm_optimize",
+                                 "spmm_buffer_size must be called with the same arguments before spmm_optimize.");
+ }
+ spmm_descr->optimized_called = true;
+ spmm_descr->last_optimized_opA = opA;
+ spmm_descr->last_optimized_opB = opB;
+ spmm_descr->last_optimized_A_view = A_view;
+ spmm_descr->last_optimized_A_handle = A_handle;
+ spmm_descr->last_optimized_B_handle = B_handle;
+ spmm_descr->last_optimized_C_handle = C_handle;
+ spmm_descr->last_optimized_alg = alg;
+}
+
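+// The rocSPARSE generic SpMM API is staged: rocsparse_spmm_stage_buffer_size
+// queries the workspace size, rocsparse_spmm_stage_preprocess runs the analysis
+// (submitted from spmm_optimize below) and rocsparse_spmm_stage_compute runs
+// the actual product from spmm.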
+void spmm_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA,
+ oneapi::mkl::transpose opB, const void* alpha,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta,
+ oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
+ oneapi::mkl::sparse::spmm_alg alg, std::size_t buffer_size,
+ void* workspace_ptr, bool is_alpha_host_accessible) {
+ auto roc_a = A_handle->backend_handle;
+ auto roc_b = B_handle->backend_handle;
+ auto roc_c = C_handle->backend_handle;
+ auto roc_op_a = get_roc_operation(opA);
+ auto roc_op_b = get_roc_operation(opB);
+ auto roc_type = get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = get_roc_spmm_alg(alg);
+ set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ // rocsparse_spmm_stage_preprocess stage is blocking
+ auto status =
+ rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, roc_c, roc_type,
+ roc_alg, rocsparse_spmm_stage_preprocess, &buffer_size, workspace_ptr);
+ check_status(status, "spmm_optimize");
+}
+
+} // namespace detail
+
+void init_spmm_descr(sycl::queue& /*queue*/, spmm_descr_t* p_spmm_descr) {
+ *p_spmm_descr = new spmm_descr();
+}
+
+sycl::event release_spmm_descr(sycl::queue& queue, spmm_descr_t spmm_descr,
+                               const std::vector<sycl::event>& dependencies) {
+ if (!spmm_descr) {
+ return detail::collapse_dependencies(queue, dependencies);
+ }
+
+ auto release_functor = [=]() {
+ spmm_descr->roc_handle = nullptr;
+ spmm_descr->last_optimized_A_handle = nullptr;
+ spmm_descr->last_optimized_B_handle = nullptr;
+ spmm_descr->last_optimized_C_handle = nullptr;
+ delete spmm_descr;
+ };
+
+ // Use dispatch_submit to ensure the descriptor is kept alive as long as the buffers are used
+ // dispatch_submit can only be used if the descriptor's handles are valid
+ if (spmm_descr->last_optimized_A_handle &&
+ spmm_descr->last_optimized_A_handle->all_use_buffer() &&
+ spmm_descr->last_optimized_B_handle && spmm_descr->last_optimized_C_handle &&
+ spmm_descr->workspace.use_buffer()) {
+        auto dispatch_functor = [=](sycl::interop_handle, sycl::accessor<std::uint8_t, 1>) {
+ release_functor();
+ };
+ return detail::dispatch_submit(
+ __func__, queue, dispatch_functor, spmm_descr->last_optimized_A_handle,
+ spmm_descr->workspace.get_buffer(), spmm_descr->last_optimized_B_handle,
+ spmm_descr->last_optimized_C_handle);
+ }
+
+    // Release from a host_task if USM is used or if the descriptor has been released before spmm_optimize has succeeded
+ sycl::event event = queue.submit([&](sycl::handler& cgh) {
+ cgh.depends_on(dependencies);
+ cgh.host_task(release_functor);
+ });
+ return event;
+}
+
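+// The host task below writes temp_buffer_size, which is captured by reference,
+// so the submitted event must be waited on before the size can be returned.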
+void spmm_buffer_size(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB,
+ const void* alpha, oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta,
+ oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
+ oneapi::mkl::sparse::spmm_alg alg,
+ oneapi::mkl::sparse::spmm_descr_t spmm_descr, std::size_t& temp_buffer_size) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ detail::check_valid_spmm(__func__, A_view, A_handle, B_handle, C_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+ bool is_in_order_queue = queue.is_in_order();
+ auto functor = [=, &temp_buffer_size](sycl::interop_handle ih) {
+ detail::RocsparseScopedContextHandler sc(queue, ih);
+ auto [roc_handle, hip_stream] = sc.get_handle_and_stream(queue);
+ spmm_descr->roc_handle = roc_handle;
+ spmm_descr->hip_stream = hip_stream;
+ auto roc_a = A_handle->backend_handle;
+ auto roc_b = B_handle->backend_handle;
+ auto roc_c = C_handle->backend_handle;
+ auto roc_op_a = detail::get_roc_operation(opA);
+ auto roc_op_b = detail::get_roc_operation(opB);
+ auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = detail::get_roc_spmm_alg(alg);
+ detail::set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ auto status = rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta,
+ roc_c, roc_type, roc_alg, rocsparse_spmm_stage_buffer_size,
+ &temp_buffer_size, nullptr);
+ detail::check_status(status, __func__);
+ detail::synchronize_if_needed(is_in_order_queue, hip_stream);
+ };
+ auto event = detail::dispatch_submit(__func__, queue, functor, A_handle, B_handle, C_handle);
+ event.wait_and_throw();
+ spmm_descr->temp_buffer_size = temp_buffer_size;
+ spmm_descr->buffer_size_called = true;
+}
+
+void spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB,
+ const void* alpha, oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta,
+ oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
+ oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr,
+                   sycl::buffer<std::uint8_t, 1> workspace) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ if (!A_handle->all_use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ detail::common_spmm_optimize(opA, opB, is_alpha_host_accessible, A_view, A_handle, B_handle,
+ is_beta_host_accessible, C_handle, alg, spmm_descr);
+    // Copy the buffer to extend its lifetime until the descriptor is freed.
+ spmm_descr->workspace.set_buffer_untyped(workspace);
+ if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
+ return;
+ }
+ std::size_t buffer_size = spmm_descr->temp_buffer_size;
+
+ // The accessor can only be created if the buffer size is greater than 0
+ if (buffer_size > 0) {
+        auto functor = [=](sycl::interop_handle ih, sycl::accessor<std::uint8_t, 1> workspace_acc) {
+ auto roc_handle = spmm_descr->roc_handle;
+ auto workspace_ptr = detail::get_mem(ih, workspace_acc);
+ detail::spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta,
+ C_handle, alg, buffer_size, workspace_ptr,
+ is_alpha_host_accessible);
+ };
+
+ detail::dispatch_submit(__func__, queue, functor, A_handle, workspace, B_handle, C_handle);
+ }
+ else {
+ auto functor = [=](sycl::interop_handle) {
+ auto roc_handle = spmm_descr->roc_handle;
+ detail::spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta,
+ C_handle, alg, buffer_size, nullptr,
+ is_alpha_host_accessible);
+ };
+
+ detail::dispatch_submit(__func__, queue, functor, A_handle, B_handle, C_handle);
+ }
+}
+
+sycl::event spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA,
+ oneapi::mkl::transpose opB, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta,
+ oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
+ oneapi::mkl::sparse::spmm_alg alg,
+ oneapi::mkl::sparse::spmm_descr_t spmm_descr, void* workspace,
+                          const std::vector<sycl::event>& dependencies) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ if (A_handle->all_use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ detail::common_spmm_optimize(opA, opB, is_alpha_host_accessible, A_view, A_handle, B_handle,
+ is_beta_host_accessible, C_handle, alg, spmm_descr);
+ spmm_descr->workspace.usm_ptr = workspace;
+ if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
+ return detail::collapse_dependencies(queue, dependencies);
+ }
+ std::size_t buffer_size = spmm_descr->temp_buffer_size;
+ auto functor = [=](sycl::interop_handle) {
+ auto roc_handle = spmm_descr->roc_handle;
+ detail::spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta, C_handle,
+ alg, buffer_size, workspace, is_alpha_host_accessible);
+ };
+
+ return detail::dispatch_submit(__func__, queue, dependencies, functor, A_handle, B_handle,
+ C_handle);
+}
+
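+// spmm verifies that its arguments match the ones passed to the last
+// spmm_optimize call (via CHECK_DESCR_MATCH) before submitting the compute stage.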
+sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB,
+ const void* alpha, oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta,
+ oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
+ oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr,
+                 const std::vector<sycl::event>& dependencies) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ if (A_handle->all_use_buffer() != spmm_descr->workspace.use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ if (!spmm_descr->optimized_called) {
+ throw mkl::uninitialized(
+ "sparse_blas", __func__,
+ "spmm_optimize must be called with the same arguments before spmm.");
+ }
+ CHECK_DESCR_MATCH(spmm_descr, opA, "spmm_optimize");
+ CHECK_DESCR_MATCH(spmm_descr, opB, "spmm_optimize");
+ CHECK_DESCR_MATCH(spmm_descr, A_view, "spmm_optimize");
+ CHECK_DESCR_MATCH(spmm_descr, A_handle, "spmm_optimize");
+ CHECK_DESCR_MATCH(spmm_descr, B_handle, "spmm_optimize");
+ CHECK_DESCR_MATCH(spmm_descr, C_handle, "spmm_optimize");
+ CHECK_DESCR_MATCH(spmm_descr, alg, "spmm_optimize");
+ detail::check_valid_spmm(__func__, A_view, A_handle, B_handle, C_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+ A_handle->mark_used();
+ auto& buffer_size = spmm_descr->temp_buffer_size;
+ bool is_in_order_queue = queue.is_in_order();
+ auto compute_functor = [=, &buffer_size](void* workspace_ptr) {
+ auto roc_handle = spmm_descr->roc_handle;
+ auto hip_stream = spmm_descr->hip_stream;
+ auto roc_a = A_handle->backend_handle;
+ auto roc_b = B_handle->backend_handle;
+ auto roc_c = C_handle->backend_handle;
+ auto roc_op_a = detail::get_roc_operation(opA);
+ auto roc_op_b = detail::get_roc_operation(opB);
+ auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = detail::get_roc_spmm_alg(alg);
+ detail::set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ auto status = rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta,
+ roc_c, roc_type, roc_alg, rocsparse_spmm_stage_compute,
+ &buffer_size, workspace_ptr);
+ detail::check_status(status, __func__);
+ detail::synchronize_if_needed(is_in_order_queue, hip_stream);
+ };
+ // The accessor can only be created if the buffer size is greater than 0
+ if (A_handle->all_use_buffer() && buffer_size > 0) {
+ auto functor_buffer = [=](sycl::interop_handle ih,
+                                  sycl::accessor<std::uint8_t, 1> workspace_acc) {
+ auto workspace_ptr = detail::get_mem(ih, workspace_acc);
+ compute_functor(workspace_ptr);
+ };
+ return detail::dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle,
+ spmm_descr->workspace.get_buffer(),
+ B_handle, C_handle);
+ }
+ else {
+ // The same dispatch_submit can be used for USM or buffers if no
+ // workspace accessor is needed.
+ // workspace_ptr will be a nullptr in the latter case.
+ auto workspace_ptr = spmm_descr->workspace.usm_ptr;
+ auto functor_usm = [=](sycl::interop_handle) {
+ compute_functor(workspace_ptr);
+ };
+ return detail::dispatch_submit_native_ext(__func__, queue, dependencies, functor_usm,
+ A_handle, B_handle, C_handle);
+ }
+}
+
+} // namespace oneapi::mkl::sparse::rocsparse
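+
+// Illustrative call sequence for this backend with USM, assuming a sycl::queue
+// `q` and previously created matrix/dense handles (names are hypothetical):
+//   spmm_descr_t descr;
+//   init_spmm_descr(q, &descr);
+//   std::size_t size = 0;
+//   spmm_buffer_size(q, opA, opB, &alpha, A_view, A, B, &beta, C, alg, descr, size);
+//   void* workspace = sycl::malloc_device(size, q);
+//   spmm_optimize(q, opA, opB, &alpha, A_view, A, B, &beta, C, alg, descr, workspace, {}).wait();
+//   spmm(q, opA, opB, &alpha, A_view, A, B, &beta, C, alg, descr, {}).wait();
+//   release_spmm_descr(q, descr, {}).wait();
+//   sycl::free(workspace, q);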
diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp
new file mode 100644
index 000000000..13f7ed11d
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp
@@ -0,0 +1,350 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp"
+
+#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp"
+#include "sparse_blas/common_op_verification.hpp"
+#include "sparse_blas/macros.hpp"
+#include "sparse_blas/matrix_view_comparison.hpp"
+#include "sparse_blas/sycl_helper.hpp"
+
+namespace oneapi::mkl::sparse {
+
+// Complete the definition of the incomplete type
+struct spmv_descr {
+ // Cache the hipStream_t and global handle to avoid relying on RocsparseScopedContextHandler to retrieve them.
+ hipStream_t hip_stream;
+ rocsparse_handle roc_handle;
+
+ detail::generic_container workspace;
+ std::size_t temp_buffer_size = 0;
+ bool buffer_size_called = false;
+ bool optimized_called = false;
+ oneapi::mkl::transpose last_optimized_opA;
+ oneapi::mkl::sparse::matrix_view last_optimized_A_view;
+ oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle;
+ oneapi::mkl::sparse::dense_vector_handle_t last_optimized_x_handle;
+ oneapi::mkl::sparse::dense_vector_handle_t last_optimized_y_handle;
+ oneapi::mkl::sparse::spmv_alg last_optimized_alg;
+};
+
+} // namespace oneapi::mkl::sparse
+
+namespace oneapi::mkl::sparse::rocsparse {
+
+namespace detail {
+
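+// Map the oneMKL spmv algorithm enum to the rocSPARSE SpMV algorithm; values
+// without a direct counterpart (default_alg, no_optimize_alg) fall back to
+// rocsparse_spmv_alg_default.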
+inline auto get_roc_spmv_alg(spmv_alg alg) {
+ switch (alg) {
+ case spmv_alg::coo_alg1: return rocsparse_spmv_alg_coo;
+ case spmv_alg::coo_alg2: return rocsparse_spmv_alg_coo_atomic;
+ case spmv_alg::csr_alg1: return rocsparse_spmv_alg_csr_adaptive;
+ case spmv_alg::csr_alg2: return rocsparse_spmv_alg_csr_stream;
+ case spmv_alg::csr_alg3: return rocsparse_spmv_alg_csr_lrb;
+ default: return rocsparse_spmv_alg_default;
+ }
+}
+
+void check_valid_spmv(const std::string& function_name, oneapi::mkl::transpose opA,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ bool is_alpha_host_accessible, bool is_beta_host_accessible) {
+ check_valid_spmv_common(function_name, opA, A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+    A_handle->check_valid_handle(function_name);
+ if (A_view.type_view != oneapi::mkl::sparse::matrix_descr::general) {
+ throw mkl::unimplemented(
+ "sparse_blas", function_name,
+ "The backend does not support spmv with a `type_view` other than `matrix_descr::general`.");
+ }
+}
+
+inline void common_spmv_optimize(oneapi::mkl::transpose opA, bool is_alpha_host_accessible,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ bool is_beta_host_accessible,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spmv_alg alg,
+ oneapi::mkl::sparse::spmv_descr_t spmv_descr) {
+ check_valid_spmv("spmv_optimize", opA, A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+ if (!spmv_descr->buffer_size_called) {
+ throw mkl::uninitialized(
+ "sparse_blas", "spmv_optimize",
+ "spmv_buffer_size must be called with the same arguments before spmv_optimize.");
+ }
+ spmv_descr->optimized_called = true;
+ spmv_descr->last_optimized_opA = opA;
+ spmv_descr->last_optimized_A_view = A_view;
+ spmv_descr->last_optimized_A_handle = A_handle;
+ spmv_descr->last_optimized_x_handle = x_handle;
+ spmv_descr->last_optimized_y_handle = y_handle;
+ spmv_descr->last_optimized_alg = alg;
+}
+
+void spmv_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void* beta,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spmv_alg alg, std::size_t buffer_size,
+ void* workspace_ptr, bool is_alpha_host_accessible) {
+ auto roc_a = A_handle->backend_handle;
+ auto roc_x = x_handle->backend_handle;
+ auto roc_y = y_handle->backend_handle;
+ auto roc_op = get_roc_operation(opA);
+ auto roc_type = get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = get_roc_spmv_alg(alg);
+ set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ // rocsparse_spmv_stage_preprocess stage is blocking
+ auto status =
+ rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg,
+ rocsparse_spmv_stage_preprocess, &buffer_size, workspace_ptr);
+ check_status(status, "spmv_optimize");
+}
+
+} // namespace detail
+
+void init_spmv_descr(sycl::queue& /*queue*/, spmv_descr_t* p_spmv_descr) {
+ *p_spmv_descr = new spmv_descr();
+}
+
+sycl::event release_spmv_descr(sycl::queue& queue, spmv_descr_t spmv_descr,
+                               const std::vector<sycl::event>& dependencies) {
+ if (!spmv_descr) {
+ return detail::collapse_dependencies(queue, dependencies);
+ }
+
+ auto release_functor = [=]() {
+ spmv_descr->roc_handle = nullptr;
+ spmv_descr->last_optimized_A_handle = nullptr;
+ spmv_descr->last_optimized_x_handle = nullptr;
+ spmv_descr->last_optimized_y_handle = nullptr;
+ delete spmv_descr;
+ };
+
+ // Use dispatch_submit to ensure the descriptor is kept alive as long as the buffers are used
+ // dispatch_submit can only be used if the descriptor's handles are valid
+ if (spmv_descr->last_optimized_A_handle &&
+ spmv_descr->last_optimized_A_handle->all_use_buffer() &&
+ spmv_descr->last_optimized_x_handle && spmv_descr->last_optimized_y_handle &&
+ spmv_descr->workspace.use_buffer()) {
+        auto dispatch_functor = [=](sycl::interop_handle, sycl::accessor<std::uint8_t, 1>) {
+ release_functor();
+ };
+ return detail::dispatch_submit(
+ __func__, queue, dispatch_functor, spmv_descr->last_optimized_A_handle,
+ spmv_descr->workspace.get_buffer(), spmv_descr->last_optimized_x_handle,
+ spmv_descr->last_optimized_y_handle);
+ }
+
+    // Release from a host_task if USM is used or if the descriptor has been released before spmv_optimize has succeeded
+ sycl::event event = queue.submit([&](sycl::handler& cgh) {
+ cgh.depends_on(dependencies);
+ cgh.host_task(release_functor);
+ });
+ return event;
+}
+
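+// spmv_buffer_size also caches the rocsparse_handle and hipStream_t on the
+// descriptor so that the optimize and compute stages can reuse them without
+// going through RocsparseScopedContextHandler again.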
+void spmv_buffer_size(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void* beta,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spmv_alg alg,
+ oneapi::mkl::sparse::spmv_descr_t spmv_descr, std::size_t& temp_buffer_size) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ detail::check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+ bool is_in_order_queue = queue.is_in_order();
+ auto functor = [=, &temp_buffer_size](sycl::interop_handle ih) {
+ detail::RocsparseScopedContextHandler sc(queue, ih);
+ auto [roc_handle, hip_stream] = sc.get_handle_and_stream(queue);
+ spmv_descr->roc_handle = roc_handle;
+ spmv_descr->hip_stream = hip_stream;
+ auto roc_a = A_handle->backend_handle;
+ auto roc_x = x_handle->backend_handle;
+ auto roc_y = y_handle->backend_handle;
+ auto roc_op = detail::get_roc_operation(opA);
+ auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = detail::get_roc_spmv_alg(alg);
+ detail::set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ auto status =
+ rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg,
+ rocsparse_spmv_stage_buffer_size, &temp_buffer_size, nullptr);
+ detail::check_status(status, __func__);
+ detail::synchronize_if_needed(is_in_order_queue, hip_stream);
+ };
+ auto event = detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle);
+ event.wait_and_throw();
+ spmv_descr->temp_buffer_size = temp_buffer_size;
+ spmv_descr->buffer_size_called = true;
+}
+
+void spmv_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void* beta,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr,
+                   sycl::buffer<std::uint8_t, 1> workspace) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ if (!A_handle->all_use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ detail::common_spmv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle,
+ is_beta_host_accessible, y_handle, alg, spmv_descr);
+    // Copy the buffer to extend its lifetime until the descriptor is freed.
+ spmv_descr->workspace.set_buffer_untyped(workspace);
+ if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) {
+ return;
+ }
+ std::size_t buffer_size = spmv_descr->temp_buffer_size;
+ // The accessor can only be created if the buffer size is greater than 0
+ if (buffer_size > 0) {
+        auto functor = [=](sycl::interop_handle ih, sycl::accessor<std::uint8_t, 1> workspace_acc) {
+ auto roc_handle = spmv_descr->roc_handle;
+ auto workspace_ptr = detail::get_mem(ih, workspace_acc);
+ detail::spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle,
+ alg, buffer_size, workspace_ptr, is_alpha_host_accessible);
+ };
+
+ detail::dispatch_submit(__func__, queue, functor, A_handle, workspace, x_handle, y_handle);
+ }
+ else {
+ auto functor = [=](sycl::interop_handle) {
+ auto roc_handle = spmv_descr->roc_handle;
+ detail::spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle,
+ alg, buffer_size, nullptr, is_alpha_host_accessible);
+ };
+
+ detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle);
+ }
+}
+
+sycl::event spmv_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void* beta,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spmv_alg alg,
+ oneapi::mkl::sparse::spmv_descr_t spmv_descr, void* workspace,
+                          const std::vector<sycl::event>& dependencies) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ if (A_handle->all_use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ detail::common_spmv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle,
+ is_beta_host_accessible, y_handle, alg, spmv_descr);
+ spmv_descr->workspace.usm_ptr = workspace;
+ if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) {
+ return detail::collapse_dependencies(queue, dependencies);
+ }
+ std::size_t buffer_size = spmv_descr->temp_buffer_size;
+ auto functor = [=](sycl::interop_handle) {
+ auto roc_handle = spmv_descr->roc_handle;
+ detail::spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle, alg,
+ buffer_size, workspace, is_alpha_host_accessible);
+ };
+
+ return detail::dispatch_submit(__func__, queue, dependencies, functor, A_handle, x_handle,
+ y_handle);
+}
+
+sycl::event spmv(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void* beta,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr,
+                 const std::vector<sycl::event>& dependencies) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
+ if (A_handle->all_use_buffer() != spmv_descr->workspace.use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ detail::check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible, is_beta_host_accessible);
+
+ if (!spmv_descr->optimized_called) {
+ throw mkl::uninitialized(
+ "sparse_blas", __func__,
+ "spmv_optimize must be called with the same arguments before spmv.");
+ }
+ CHECK_DESCR_MATCH(spmv_descr, opA, "spmv_optimize");
+ CHECK_DESCR_MATCH(spmv_descr, A_view, "spmv_optimize");
+ CHECK_DESCR_MATCH(spmv_descr, A_handle, "spmv_optimize");
+ CHECK_DESCR_MATCH(spmv_descr, x_handle, "spmv_optimize");
+ CHECK_DESCR_MATCH(spmv_descr, y_handle, "spmv_optimize");
+ CHECK_DESCR_MATCH(spmv_descr, alg, "spmv_optimize");
+
+ A_handle->mark_used();
+ auto& buffer_size = spmv_descr->temp_buffer_size;
+ bool is_in_order_queue = queue.is_in_order();
+ auto compute_functor = [=, &buffer_size](void* workspace_ptr) {
+ auto roc_handle = spmv_descr->roc_handle;
+ auto hip_stream = spmv_descr->hip_stream;
+ auto roc_a = A_handle->backend_handle;
+ auto roc_x = x_handle->backend_handle;
+ auto roc_y = y_handle->backend_handle;
+ auto roc_op = detail::get_roc_operation(opA);
+ auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = detail::get_roc_spmv_alg(alg);
+ detail::set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ auto status =
+ rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg,
+ rocsparse_spmv_stage_compute, &buffer_size, workspace_ptr);
+ detail::check_status(status, __func__);
+ detail::synchronize_if_needed(is_in_order_queue, hip_stream);
+ };
+ // The accessor can only be created if the buffer size is greater than 0
+ if (A_handle->all_use_buffer() && buffer_size > 0) {
+ auto functor_buffer = [=](sycl::interop_handle ih,
+                                  sycl::accessor<std::uint8_t, 1> workspace_acc) {
+ auto workspace_ptr = detail::get_mem(ih, workspace_acc);
+ compute_functor(workspace_ptr);
+ };
+ return detail::dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle,
+ spmv_descr->workspace.get_buffer(),
+ x_handle, y_handle);
+ }
+ else {
+ // The same dispatch_submit can be used for USM or buffers if no
+ // workspace accessor is needed.
+ // workspace_ptr will be a nullptr in the latter case.
+ auto workspace_ptr = spmv_descr->workspace.usm_ptr;
+ auto functor_usm = [=](sycl::interop_handle) {
+ compute_functor(workspace_ptr);
+ };
+ return detail::dispatch_submit_native_ext(__func__, queue, dependencies, functor_usm,
+ A_handle, x_handle, y_handle);
+ }
+}
+
+} // namespace oneapi::mkl::sparse::rocsparse
diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp
new file mode 100644
index 000000000..d05afccf0
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp
@@ -0,0 +1,331 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp"
+
+#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp"
+#include "sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp"
+#include "sparse_blas/common_op_verification.hpp"
+#include "sparse_blas/macros.hpp"
+#include "sparse_blas/matrix_view_comparison.hpp"
+#include "sparse_blas/sycl_helper.hpp"
+
+namespace oneapi::mkl::sparse {
+
+// Complete the definition of the incomplete type
+struct spsv_descr {
+ // Cache the hipStream_t and global handle to avoid relying on RocsparseScopedContextHandler to retrieve them.
+ hipStream_t hip_stream;
+ rocsparse_handle roc_handle;
+
+ detail::generic_container workspace;
+ std::size_t temp_buffer_size = 0;
+ bool buffer_size_called = false;
+ bool optimized_called = false;
+ oneapi::mkl::transpose last_optimized_opA;
+ oneapi::mkl::sparse::matrix_view last_optimized_A_view;
+ oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle;
+ oneapi::mkl::sparse::dense_vector_handle_t last_optimized_x_handle;
+ oneapi::mkl::sparse::dense_vector_handle_t last_optimized_y_handle;
+ oneapi::mkl::sparse::spsv_alg last_optimized_alg;
+};
+
+} // namespace oneapi::mkl::sparse
+
+namespace oneapi::mkl::sparse::rocsparse {
+
+namespace detail {
+
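+// rocSPARSE only exposes a default SpSV algorithm, so every oneMKL spsv_alg
+// value maps to rocsparse_spsv_alg_default.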
+inline auto get_roc_spsv_alg(spsv_alg /*alg*/) {
+ return rocsparse_spsv_alg_default;
+}
+
+void check_valid_spsv(const std::string& function_name, matrix_view A_view,
+ matrix_handle_t A_handle, dense_vector_handle_t x_handle,
+ dense_vector_handle_t y_handle, bool is_alpha_host_accessible) {
+ check_valid_spsv_common(function_name, A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible);
+ A_handle->check_valid_handle(function_name);
+}
+
+inline void common_spsv_optimize(oneapi::mkl::transpose opA, bool is_alpha_host_accessible,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spsv_alg alg,
+ oneapi::mkl::sparse::spsv_descr_t spsv_descr) {
+ check_valid_spsv("spsv_optimize", A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible);
+ if (!spsv_descr->buffer_size_called) {
+ throw mkl::uninitialized(
+ "sparse_blas", "spsv_optimize",
+ "spsv_buffer_size must be called with the same arguments before spsv_optimize.");
+ }
+ spsv_descr->optimized_called = true;
+ spsv_descr->last_optimized_opA = opA;
+ spsv_descr->last_optimized_A_view = A_view;
+ spsv_descr->last_optimized_A_handle = A_handle;
+ spsv_descr->last_optimized_x_handle = x_handle;
+ spsv_descr->last_optimized_y_handle = y_handle;
+ spsv_descr->last_optimized_alg = alg;
+}
+
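+// Unlike spmv, the matrix_view is forwarded down to spsv_optimize_impl so that
+// the fill mode and diagonal type can be set on the rocSPARSE matrix descriptor
+// (see set_matrix_attributes) before the blocking preprocess stage runs.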
+void spsv_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spsv_alg alg, std::size_t buffer_size,
+ void* workspace_ptr, bool is_alpha_host_accessible) {
+ auto roc_a = A_handle->backend_handle;
+ auto roc_x = x_handle->backend_handle;
+ auto roc_y = y_handle->backend_handle;
+ set_matrix_attributes("spsv_optimize", roc_a, A_view);
+ auto roc_op = get_roc_operation(opA);
+ auto roc_type = get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = get_roc_spsv_alg(alg);
+ set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ // rocsparse_spsv_stage_preprocess stage is blocking
+ auto status = rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg,
+ rocsparse_spsv_stage_preprocess, &buffer_size, workspace_ptr);
+ check_status(status, "spsv_optimize");
+}
+
+} // namespace detail
+
+void init_spsv_descr(sycl::queue& /*queue*/, spsv_descr_t* p_spsv_descr) {
+ *p_spsv_descr = new spsv_descr();
+}
+
+sycl::event release_spsv_descr(sycl::queue& queue, spsv_descr_t spsv_descr,
+                               const std::vector<sycl::event>& dependencies) {
+ if (!spsv_descr) {
+ return detail::collapse_dependencies(queue, dependencies);
+ }
+
+ auto release_functor = [=]() {
+ spsv_descr->roc_handle = nullptr;
+ spsv_descr->last_optimized_A_handle = nullptr;
+ spsv_descr->last_optimized_x_handle = nullptr;
+ spsv_descr->last_optimized_y_handle = nullptr;
+ delete spsv_descr;
+ };
+
+ // Use dispatch_submit to ensure the descriptor is kept alive as long as the buffers are used
+ // dispatch_submit can only be used if the descriptor's handles are valid
+ if (spsv_descr->last_optimized_A_handle &&
+ spsv_descr->last_optimized_A_handle->all_use_buffer() &&
+ spsv_descr->last_optimized_x_handle && spsv_descr->last_optimized_y_handle &&
+ spsv_descr->workspace.use_buffer()) {
+        auto dispatch_functor = [=](sycl::interop_handle, sycl::accessor<std::uint8_t, 1>) {
+ release_functor();
+ };
+ return detail::dispatch_submit(
+ __func__, queue, dispatch_functor, spsv_descr->last_optimized_A_handle,
+ spsv_descr->workspace.get_buffer(), spsv_descr->last_optimized_x_handle,
+ spsv_descr->last_optimized_y_handle);
+ }
+
+    // Release from a host_task if USM is used or if the descriptor has been released before spsv_optimize has succeeded
+ sycl::event event = queue.submit([&](sycl::handler& cgh) {
+ cgh.depends_on(dependencies);
+ cgh.host_task(release_functor);
+ });
+ return event;
+}
+
+void spsv_buffer_size(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spsv_alg alg,
+ oneapi::mkl::sparse::spsv_descr_t spsv_descr, std::size_t& temp_buffer_size) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ detail::check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible);
+ bool is_in_order_queue = queue.is_in_order();
+ auto functor = [=, &temp_buffer_size](sycl::interop_handle ih) {
+ detail::RocsparseScopedContextHandler sc(queue, ih);
+ auto [roc_handle, hip_stream] = sc.get_handle_and_stream(queue);
+ spsv_descr->roc_handle = roc_handle;
+ spsv_descr->hip_stream = hip_stream;
+ auto roc_a = A_handle->backend_handle;
+ auto roc_x = x_handle->backend_handle;
+ auto roc_y = y_handle->backend_handle;
+ detail::set_matrix_attributes(__func__, roc_a, A_view);
+ auto roc_op = detail::get_roc_operation(opA);
+ auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = detail::get_roc_spsv_alg(alg);
+ detail::set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ auto status =
+ rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg,
+ rocsparse_spsv_stage_buffer_size, &temp_buffer_size, nullptr);
+ detail::check_status(status, __func__);
+ detail::synchronize_if_needed(is_in_order_queue, hip_stream);
+ };
+ auto event = detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle);
+ event.wait_and_throw();
+ spsv_descr->temp_buffer_size = temp_buffer_size;
+ spsv_descr->buffer_size_called = true;
+}
+
+void spsv_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr,
+                   sycl::buffer<std::uint8_t, 1> workspace) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ if (!A_handle->all_use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ A_handle->check_valid_handle(__func__);
+ detail::common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle,
+ y_handle, alg, spsv_descr);
+ // Ignore spsv_alg::no_optimize_alg as this step is mandatory for rocSPARSE
+    // Copy the buffer to extend its lifetime until the descriptor is freed.
+ spsv_descr->workspace.set_buffer_untyped(workspace);
+ std::size_t buffer_size = spsv_descr->temp_buffer_size;
+ // The accessor can only be created if the buffer size is greater than 0
+ if (buffer_size > 0) {
+        auto functor = [=](sycl::interop_handle ih, sycl::accessor<std::uint8_t, 1> workspace_acc) {
+ auto roc_handle = spsv_descr->roc_handle;
+ auto workspace_ptr = detail::get_mem(ih, workspace_acc);
+ detail::spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle,
+ alg, buffer_size, workspace_ptr, is_alpha_host_accessible);
+ };
+
+ detail::dispatch_submit(__func__, queue, functor, A_handle, workspace, x_handle, y_handle);
+ }
+ else {
+ auto functor = [=](sycl::interop_handle) {
+ auto roc_handle = spsv_descr->roc_handle;
+ detail::spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle,
+ alg, buffer_size, nullptr, is_alpha_host_accessible);
+ };
+
+ detail::dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle);
+ }
+}
+
+sycl::event spsv_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spsv_alg alg,
+ oneapi::mkl::sparse::spsv_descr_t spsv_descr, void* workspace,
+                          const std::vector<sycl::event>& dependencies) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ if (A_handle->all_use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ A_handle->check_valid_handle(__func__);
+ detail::common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle,
+ y_handle, alg, spsv_descr);
+ spsv_descr->workspace.usm_ptr = workspace;
+ // Ignore spsv_alg::no_optimize_alg as this step is mandatory for rocSPARSE
+ std::size_t buffer_size = spsv_descr->temp_buffer_size;
+ auto functor = [=](sycl::interop_handle) {
+ auto roc_handle = spsv_descr->roc_handle;
+ detail::spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle,
+ alg, buffer_size, workspace, is_alpha_host_accessible);
+ };
+
+ return detail::dispatch_submit(__func__, queue, dependencies, functor, A_handle, x_handle,
+ y_handle);
+}
+
+sycl::event spsv(sycl::queue& queue, oneapi::mkl::transpose opA, const void* alpha,
+ oneapi::mkl::sparse::matrix_view A_view,
+ oneapi::mkl::sparse::matrix_handle_t A_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t x_handle,
+ oneapi::mkl::sparse::dense_vector_handle_t y_handle,
+ oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr,
+                 const std::vector<sycl::event>& dependencies) {
+ bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
+ if (A_handle->all_use_buffer() != spsv_descr->workspace.use_buffer()) {
+ detail::throw_incompatible_container(__func__);
+ }
+ detail::check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle,
+ is_alpha_host_accessible);
+
+ if (!spsv_descr->optimized_called) {
+ throw mkl::uninitialized(
+ "sparse_blas", __func__,
+ "spsv_optimize must be called with the same arguments before spsv.");
+ }
+ CHECK_DESCR_MATCH(spsv_descr, opA, "spsv_optimize");
+ CHECK_DESCR_MATCH(spsv_descr, A_view, "spsv_optimize");
+ CHECK_DESCR_MATCH(spsv_descr, A_handle, "spsv_optimize");
+ CHECK_DESCR_MATCH(spsv_descr, x_handle, "spsv_optimize");
+ CHECK_DESCR_MATCH(spsv_descr, y_handle, "spsv_optimize");
+ CHECK_DESCR_MATCH(spsv_descr, alg, "spsv_optimize");
+
+ A_handle->mark_used();
+ auto& buffer_size = spsv_descr->temp_buffer_size;
+ bool is_in_order_queue = queue.is_in_order();
+ auto compute_functor = [=, &buffer_size](void* workspace_ptr) {
+ auto roc_handle = spsv_descr->roc_handle;
+ auto hip_stream = spsv_descr->hip_stream;
+ auto roc_a = A_handle->backend_handle;
+ auto roc_x = x_handle->backend_handle;
+ auto roc_y = y_handle->backend_handle;
+ detail::set_matrix_attributes(__func__, roc_a, A_view);
+ auto roc_op = detail::get_roc_operation(opA);
+ auto roc_type = detail::get_roc_value_type(A_handle->value_container.data_type);
+ auto roc_alg = detail::get_roc_spsv_alg(alg);
+ detail::set_pointer_mode(roc_handle, is_alpha_host_accessible);
+ auto status =
+ rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg,
+ rocsparse_spsv_stage_compute, &buffer_size, workspace_ptr);
+ detail::check_status(status, __func__);
+ detail::synchronize_if_needed(is_in_order_queue, hip_stream);
+ };
+ // The accessor can only be created if the buffer size is greater than 0
+ if (A_handle->all_use_buffer() && buffer_size > 0) {
+ auto functor_buffer = [=](sycl::interop_handle ih,
+                                  sycl::accessor<std::uint8_t, 1> workspace_acc) {
+ auto workspace_ptr = detail::get_mem(ih, workspace_acc);
+ compute_functor(workspace_ptr);
+ };
+ return detail::dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle,
+ spsv_descr->workspace.get_buffer(),
+ x_handle, y_handle);
+ }
+ else {
+ // The same dispatch_submit can be used for USM or buffers if no
+ // workspace accessor is needed.
+ // workspace_ptr will be a nullptr in the latter case.
+ auto workspace_ptr = spsv_descr->workspace.usm_ptr;
+ auto functor_usm = [=](sycl::interop_handle) {
+ compute_functor(workspace_ptr);
+ };
+ return detail::dispatch_submit_native_ext(__func__, queue, dependencies, functor_usm,
+ A_handle, x_handle, y_handle);
+ }
+}
+
+} // namespace oneapi::mkl::sparse::rocsparse
diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp
new file mode 100644
index 000000000..cd3191ad2
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp
@@ -0,0 +1,126 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_
+#define _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_
+
+#include <string>
+
+#include <hip/hip_runtime.h>
+#include <rocsparse/rocsparse.h>
+
+#include "oneapi/mkl/exceptions.hpp"
+
+namespace oneapi::mkl::sparse::rocsparse::detail {
+
+inline std::string hip_result_to_str(hipError_t result) {
+ switch (result) {
+#define ONEMKL_ROCSPARSE_CASE(STATUS) \
+ case STATUS: return #STATUS
+ ONEMKL_ROCSPARSE_CASE(hipSuccess);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInvalidContext);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInvalidKernelFile);
+ ONEMKL_ROCSPARSE_CASE(hipErrorMemoryAllocation);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInitializationError);
+ ONEMKL_ROCSPARSE_CASE(hipErrorLaunchFailure);
+ ONEMKL_ROCSPARSE_CASE(hipErrorLaunchOutOfResources);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInvalidDevice);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInvalidValue);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInvalidDevicePointer);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInvalidMemcpyDirection);
+ ONEMKL_ROCSPARSE_CASE(hipErrorUnknown);
+ ONEMKL_ROCSPARSE_CASE(hipErrorInvalidResourceHandle);
+ ONEMKL_ROCSPARSE_CASE(hipErrorNotReady);
+ ONEMKL_ROCSPARSE_CASE(hipErrorNoDevice);
+ ONEMKL_ROCSPARSE_CASE(hipErrorPeerAccessAlreadyEnabled);
+ ONEMKL_ROCSPARSE_CASE(hipErrorPeerAccessNotEnabled);
+ ONEMKL_ROCSPARSE_CASE(hipErrorRuntimeMemory);
+ ONEMKL_ROCSPARSE_CASE(hipErrorRuntimeOther);
+ ONEMKL_ROCSPARSE_CASE(hipErrorHostMemoryAlreadyRegistered);
+ ONEMKL_ROCSPARSE_CASE(hipErrorHostMemoryNotRegistered);
+ ONEMKL_ROCSPARSE_CASE(hipErrorMapBufferObjectFailed);
+ ONEMKL_ROCSPARSE_CASE(hipErrorTbd);
+        default: return "<unknown>";
+ }
+}
+
+#define HIP_ERROR_FUNC(func, ...) \
+ do { \
+ auto res = func(__VA_ARGS__); \
+ if (res != hipSuccess) { \
+ throw oneapi::mkl::exception("sparse_blas", #func, \
+ "hip error: " + detail::hip_result_to_str(res)); \
+ } \
+ } while (0)
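+
+// Example (illustrative): HIP_ERROR_FUNC(hipStreamSynchronize, hip_stream)
+// expands to a checked call to hipStreamSynchronize(hip_stream) that throws
+// oneapi::mkl::exception if the result is not hipSuccess.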
+
+inline std::string rocsparse_status_to_str(rocsparse_status status) {
+ switch (status) {
+#define ONEMKL_ROCSPARSE_CASE(STATUS) \
+ case STATUS: return #STATUS
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_success);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_handle);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_not_implemented);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_pointer);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_size);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_memory_error);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_internal_error);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_value);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_arch_mismatch);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_zero_pivot);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_not_initialized);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_type_mismatch);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_requires_sorted_storage);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_thrown_exception);
+ ONEMKL_ROCSPARSE_CASE(rocsparse_status_continue);
+#undef ONEMKL_ROCSPARSE_CASE
+        default: return "<unknown>";
+ }
+}
+
+inline void check_status(rocsparse_status status, const std::string& function,
+ std::string error_str = "") {
+ if (status != rocsparse_status_success) {
+ if (!error_str.empty()) {
+ error_str += "; ";
+ }
+ error_str += "rocSPARSE status: " + rocsparse_status_to_str(status);
+ switch (status) {
+ case rocsparse_status_not_implemented:
+ throw oneapi::mkl::unimplemented("sparse_blas", function, error_str);
+ case rocsparse_status_invalid_handle:
+ case rocsparse_status_invalid_pointer:
+ case rocsparse_status_invalid_size:
+ case rocsparse_status_invalid_value:
+ throw oneapi::mkl::invalid_argument("sparse_blas", function, error_str);
+ case rocsparse_status_not_initialized:
+ throw oneapi::mkl::uninitialized("sparse_blas", function, error_str);
+ default: throw oneapi::mkl::exception("sparse_blas", function, error_str);
+ }
+ }
+}
+
+#define ROCSPARSE_ERR_FUNC(func, ...) \
+ do { \
+ auto status = func(__VA_ARGS__); \
+ detail::check_status(status, #func); \
+ } while (0)
+
+} // namespace oneapi::mkl::sparse::rocsparse::detail
+
+#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_
diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp
new file mode 100644
index 000000000..bba2b5b1d
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp
@@ -0,0 +1,63 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_
+#define _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_
+
+/**
+ * @file Similar to blas_handle.hpp
+ * Provides a map from a ur_context_handle_t (or equivalent) to a rocsparse_handle.
+ * @see rocsparse_scope_handle.hpp
+*/
+
+#include <atomic>
+#include <unordered_map>
+
+namespace oneapi::mkl::sparse::rocsparse::detail {
+
+template <typename T>
+struct rocsparse_global_handle {
+    using handle_container_t = std::unordered_map<T, std::atomic<rocsparse_handle>*>;
+ handle_container_t rocsparse_global_handle_mapper_{};
+
+ ~rocsparse_global_handle() noexcept(false) {
+ for (auto& handle_pair : rocsparse_global_handle_mapper_) {
+ if (handle_pair.second != nullptr) {
+ auto handle = handle_pair.second->exchange(nullptr);
+ if (handle != nullptr) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_handle, handle);
+ handle = nullptr;
+ }
+ else {
+ // if the handle is nullptr it means the handle was already
+ // destroyed by the ContextCallback and we're free to delete the
+ // atomic object.
+ delete handle_pair.second;
+ }
+
+ handle_pair.second = nullptr;
+ }
+ }
+ rocsparse_global_handle_mapper_.clear();
+ }
+};
+
+} // namespace oneapi::mkl::sparse::rocsparse::detail
+
+#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_
diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp b/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp
new file mode 100644
index 000000000..bfa23dfbd
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp
@@ -0,0 +1,491 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp"
+
+#include "rocsparse_error.hpp"
+#include "rocsparse_helper.hpp"
+#include "rocsparse_handles.hpp"
+#include "rocsparse_scope_handle.hpp"
+#include "rocsparse_task.hpp"
+#include "sparse_blas/macros.hpp"
+
+namespace oneapi::mkl::sparse::rocsparse {
+
+/**
+ * In this file, RocsparseScopedContextHandler is used to ensure that a rocsparse_handle is created before any other rocSPARSE call, as required by the specification.
+*/
+
+// Dense vector
+template <typename fpType>
+void init_dense_vector(sycl::queue& queue, dense_vector_handle_t* p_dvhandle, std::int64_t size,
+                       sycl::buffer<fpType, 1> val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+        auto acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+            auto roc_value_type = detail::RocEnumType<fpType>::value;
+ rocsparse_dnvec_descr roc_dvhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &roc_dvhandle, size,
+ detail::get_mem(ih, acc), roc_value_type);
+ *p_dvhandle = new dense_vector_handle(roc_dvhandle, val, size);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType>
+void init_dense_vector(sycl::queue& queue, dense_vector_handle_t* p_dvhandle, std::int64_t size,
+ fpType* val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+            auto roc_value_type = detail::RocEnumType<fpType>::value;
+ rocsparse_dnvec_descr roc_dvhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &roc_dvhandle, size, val,
+ roc_value_type);
+ *p_dvhandle = new dense_vector_handle(roc_dvhandle, val, size);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType>
+void set_dense_vector_data(sycl::queue& queue, dense_vector_handle_t dvhandle, std::int64_t size,
+ sycl::buffer<fpType, 1> val) {
+ detail::check_can_reset_value_handle(__func__, dvhandle, true);
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ auto acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ if (dvhandle->size != size) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &dvhandle->backend_handle, size,
+ detail::get_mem(ih, acc), roc_value_type);
+ dvhandle->size = size;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_dnvec_set_values, dvhandle->backend_handle,
+ detail::get_mem(ih, acc));
+ }
+ dvhandle->set_buffer(val);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType>
+void set_dense_vector_data(sycl::queue&, dense_vector_handle_t dvhandle, std::int64_t size,
+ fpType* val) {
+ detail::check_can_reset_value_handle(__func__, dvhandle, false);
+ if (dvhandle->size != size) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &dvhandle->backend_handle, size, val,
+ roc_value_type);
+ dvhandle->size = size;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_dnvec_set_values, dvhandle->backend_handle, val);
+ }
+ dvhandle->set_usm_ptr(val);
+}
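+
+// Both set_dense_vector_data overloads above share one pattern, repeated for
+// the dense matrix and sparse matrix handles below: if the new dimensions
+// match the ones stored in the handle, only the data pointers are swapped
+// (rocsparse_*_set_values / rocsparse_*_set_pointers); otherwise the backend
+// descriptor is destroyed and recreated with the new dimensions.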
+
+FOR_EACH_FP_TYPE(INSTANTIATE_DENSE_VECTOR_FUNCS);
+
+sycl::event release_dense_vector(sycl::queue& queue, dense_vector_handle_t dvhandle,
+ const std::vector<sycl::event>& dependencies) {
+ // Use dispatch_submit_impl_fp to ensure the backend's handle is kept alive as long as the buffer is used
+ auto functor = [=](sycl::interop_handle) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle);
+ delete dvhandle;
+ };
+ return detail::dispatch_submit_impl_fp(__func__, queue, dependencies, functor, dvhandle);
+}
+
+// Dense matrix
+template <typename fpType>
+void init_dense_matrix(sycl::queue& queue, dense_matrix_handle_t* p_dmhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t ld, layout dense_layout,
+ sycl::buffer<fpType, 1> val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ auto acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ auto roc_order = detail::get_roc_order(dense_layout);
+ rocsparse_dnmat_descr roc_dmhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &roc_dmhandle, num_rows, num_cols, ld,
+ detail::get_mem(ih, acc), roc_value_type, roc_order);
+ *p_dmhandle =
+ new dense_matrix_handle(roc_dmhandle, val, num_rows, num_cols, ld, dense_layout);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType>
+void init_dense_matrix(sycl::queue& queue, dense_matrix_handle_t* p_dmhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t ld, layout dense_layout, fpType* val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ auto roc_order = detail::get_roc_order(dense_layout);
+ rocsparse_dnmat_descr roc_dmhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &roc_dmhandle, num_rows, num_cols, ld,
+ val, roc_value_type, roc_order);
+ *p_dmhandle =
+ new dense_matrix_handle(roc_dmhandle, val, num_rows, num_cols, ld, dense_layout);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType>
+void set_dense_matrix_data(sycl::queue& queue, dense_matrix_handle_t dmhandle,
+ std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld,
+ oneapi::mkl::layout dense_layout, sycl::buffer<fpType, 1> val) {
+ detail::check_can_reset_value_handle(__func__, dmhandle, true);
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ auto acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ if (dmhandle->num_rows != num_rows || dmhandle->num_cols != num_cols ||
+ dmhandle->ld != ld || dmhandle->dense_layout != dense_layout) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ auto roc_order = detail::get_roc_order(dense_layout);
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &dmhandle->backend_handle,
+ num_rows, num_cols, ld, detail::get_mem(ih, acc), roc_value_type,
+ roc_order);
+ dmhandle->num_rows = num_rows;
+ dmhandle->num_cols = num_cols;
+ dmhandle->ld = ld;
+ dmhandle->dense_layout = dense_layout;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_dnmat_set_values, dmhandle->backend_handle,
+ detail::get_mem(ih, acc));
+ }
+ dmhandle->set_buffer(val);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType>
+void set_dense_matrix_data(sycl::queue&, dense_matrix_handle_t dmhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t ld, oneapi::mkl::layout dense_layout,
+ fpType* val) {
+ detail::check_can_reset_value_handle(__func__, dmhandle, false);
+ if (dmhandle->num_rows != num_rows || dmhandle->num_cols != num_cols || dmhandle->ld != ld ||
+ dmhandle->dense_layout != dense_layout) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ auto roc_order = detail::get_roc_order(dense_layout);
+ ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &dmhandle->backend_handle, num_rows,
+ num_cols, ld, val, roc_value_type, roc_order);
+ dmhandle->num_rows = num_rows;
+ dmhandle->num_cols = num_cols;
+ dmhandle->ld = ld;
+ dmhandle->dense_layout = dense_layout;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_dnmat_set_values, dmhandle->backend_handle, val);
+ }
+ dmhandle->set_usm_ptr(val);
+}
+
+FOR_EACH_FP_TYPE(INSTANTIATE_DENSE_MATRIX_FUNCS);
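+
+// Usage sketch (hypothetical caller, USM variant): for layout::row_major the
+// leading dimension `ld` is the stride between consecutive rows, so
+// ld >= num_cols; for layout::col_major it is the stride between consecutive
+// columns, so ld >= num_rows:
+//
+//   dense_matrix_handle_t dmhandle = nullptr;
+//   float* a = sycl::malloc_device<float>(num_rows * ld, queue);
+//   init_dense_matrix(queue, &dmhandle, num_rows, num_cols, ld,
+//                     oneapi::mkl::layout::row_major, a);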
+
+sycl::event release_dense_matrix(sycl::queue& queue, dense_matrix_handle_t dmhandle,
+ const std::vector<sycl::event>& dependencies) {
+ // Use dispatch_submit_impl_fp to ensure the backend's handle is kept alive as long as the buffer is used
+ auto functor = [=](sycl::interop_handle) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle);
+ delete dmhandle;
+ };
+ return detail::dispatch_submit_impl_fp(__func__, queue, dependencies, functor, dmhandle);
+}
+
+// COO matrix
+template <typename fpType, typename intType>
+void init_coo_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ sycl::buffer<intType, 1> row_ind, sycl::buffer<intType, 1> col_ind,
+ sycl::buffer<fpType, 1> val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ auto row_acc = row_ind.template get_access<sycl::access::mode::read_write>(cgh);
+ auto col_acc = col_ind.template get_access<sycl::access::mode::read_write>(cgh);
+ auto val_acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ rocsparse_spmat_descr roc_smhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &roc_smhandle, num_rows, num_cols, nnz,
+ detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc),
+ detail::get_mem(ih, val_acc), roc_index_type, roc_index_base,
+ roc_value_type);
+ *p_smhandle =
+ new matrix_handle(roc_smhandle, row_ind, col_ind, val, detail::sparse_format::COO,
+ num_rows, num_cols, nnz, index);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType, typename intType>
+void init_coo_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ intType* row_ind, intType* col_ind, fpType* val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ rocsparse_spmat_descr roc_smhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &roc_smhandle, num_rows, num_cols, nnz,
+ row_ind, col_ind, val, roc_index_type, roc_index_base,
+ roc_value_type);
+ *p_smhandle =
+ new matrix_handle(roc_smhandle, row_ind, col_ind, val, detail::sparse_format::COO,
+ num_rows, num_cols, nnz, index);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType, typename intType>
+void set_coo_matrix_data(sycl::queue& queue, matrix_handle_t smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ sycl::buffer<intType, 1> row_ind, sycl::buffer<intType, 1> col_ind,
+ sycl::buffer<fpType, 1> val) {
+ detail::check_can_reset_sparse_handle(__func__, smhandle, true);
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ auto row_acc = row_ind.template get_access<sycl::access::mode::read_write>(cgh);
+ auto col_acc = col_ind.template get_access<sycl::access::mode::read_write>(cgh);
+ auto val_acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols ||
+ smhandle->nnz != nnz || smhandle->index != index) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &smhandle->backend_handle, num_rows,
+ num_cols, nnz, detail::get_mem(ih, row_acc),
+ detail::get_mem(ih, col_acc), detail::get_mem(ih, val_acc),
+ roc_index_type, roc_index_base, roc_value_type);
+ smhandle->num_rows = num_rows;
+ smhandle->num_cols = num_cols;
+ smhandle->nnz = nnz;
+ smhandle->index = index;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_coo_set_pointers, smhandle->backend_handle,
+ detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc),
+ detail::get_mem(ih, val_acc));
+ }
+ smhandle->row_container.set_buffer(row_ind);
+ smhandle->col_container.set_buffer(col_ind);
+ smhandle->value_container.set_buffer(val);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType, typename intType>
+void set_coo_matrix_data(sycl::queue&, matrix_handle_t smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ intType* row_ind, intType* col_ind, fpType* val) {
+ detail::check_can_reset_sparse_handle(__func__, smhandle, false);
+ if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || smhandle->nnz != nnz ||
+ smhandle->index != index) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &smhandle->backend_handle, num_rows,
+ num_cols, nnz, row_ind, col_ind, val, roc_index_type, roc_index_base,
+ roc_value_type);
+ smhandle->num_rows = num_rows;
+ smhandle->num_cols = num_cols;
+ smhandle->nnz = nnz;
+ smhandle->index = index;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_coo_set_pointers, smhandle->backend_handle, row_ind, col_ind,
+ val);
+ }
+ smhandle->row_container.set_usm_ptr(row_ind);
+ smhandle->col_container.set_usm_ptr(col_ind);
+ smhandle->value_container.set_usm_ptr(val);
+}
+
+FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_COO_MATRIX_FUNCS);
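+
+// Usage sketch (hypothetical values, USM variant): a 3x3 identity in COO
+// format with zero-based indexing stores one (row, col, value) triplet per
+// nonzero, i.e. row_ind = {0, 1, 2}, col_ind = {0, 1, 2}, val = {1, 1, 1}:
+//
+//   matrix_handle_t smhandle = nullptr;
+//   init_coo_matrix(queue, &smhandle, 3, 3, 3, oneapi::mkl::index_base::zero,
+//                   row_ind, col_ind, val);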
+
+// CSR matrix
+template <typename fpType, typename intType>
+void init_csr_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ sycl::buffer<intType, 1> row_ptr, sycl::buffer<intType, 1> col_ind,
+ sycl::buffer<fpType, 1> val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ auto row_acc = row_ptr.template get_access<sycl::access::mode::read_write>(cgh);
+ auto col_acc = col_ind.template get_access<sycl::access::mode::read_write>(cgh);
+ auto val_acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ rocsparse_spmat_descr roc_smhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &roc_smhandle, num_rows, num_cols, nnz,
+ detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc),
+ detail::get_mem(ih, val_acc), roc_index_type, roc_index_type,
+ roc_index_base, roc_value_type);
+ *p_smhandle =
+ new matrix_handle(roc_smhandle, row_ptr, col_ind, val, detail::sparse_format::CSR,
+ num_rows, num_cols, nnz, index);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType, typename intType>
+void init_csr_matrix(sycl::queue& queue, matrix_handle_t* p_smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ intType* row_ptr, intType* col_ind, fpType* val) {
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ // Ensure that a rocsparse handle is created before any other rocSPARSE function is called.
+ detail::RocsparseScopedContextHandler(queue, ih).get_handle(queue);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ rocsparse_spmat_descr roc_smhandle;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &roc_smhandle, num_rows, num_cols, nnz,
+ row_ptr, col_ind, val, roc_index_type, roc_index_type,
+ roc_index_base, roc_value_type);
+ *p_smhandle =
+ new matrix_handle(roc_smhandle, row_ptr, col_ind, val, detail::sparse_format::CSR,
+ num_rows, num_cols, nnz, index);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType, typename intType>
+void set_csr_matrix_data(sycl::queue& queue, matrix_handle_t smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ sycl::buffer<intType, 1> row_ptr, sycl::buffer<intType, 1> col_ind,
+ sycl::buffer<fpType, 1> val) {
+ detail::check_can_reset_sparse_handle(__func__, smhandle, true);
+ auto event = queue.submit([&](sycl::handler& cgh) {
+ auto row_acc = row_ptr.template get_access<sycl::access::mode::read_write>(cgh);
+ auto col_acc = col_ind.template get_access<sycl::access::mode::read_write>(cgh);
+ auto val_acc = val.template get_access<sycl::access::mode::read_write>(cgh);
+ detail::submit_host_task(cgh, queue, [=](sycl::interop_handle ih) {
+ if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols ||
+ smhandle->nnz != nnz || smhandle->index != index) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &smhandle->backend_handle, num_rows,
+ num_cols, nnz, detail::get_mem(ih, row_acc),
+ detail::get_mem(ih, col_acc), detail::get_mem(ih, val_acc),
+ roc_index_type, roc_index_type, roc_index_base, roc_value_type);
+ smhandle->num_rows = num_rows;
+ smhandle->num_cols = num_cols;
+ smhandle->nnz = nnz;
+ smhandle->index = index;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_csr_set_pointers, smhandle->backend_handle,
+ detail::get_mem(ih, row_acc), detail::get_mem(ih, col_acc),
+ detail::get_mem(ih, val_acc));
+ }
+ smhandle->row_container.set_buffer(row_ptr);
+ smhandle->col_container.set_buffer(col_ind);
+ smhandle->value_container.set_buffer(val);
+ });
+ });
+ event.wait_and_throw();
+}
+
+template <typename fpType, typename intType>
+void set_csr_matrix_data(sycl::queue&, matrix_handle_t smhandle, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index,
+ intType* row_ptr, intType* col_ind, fpType* val) {
+ detail::check_can_reset_sparse_handle(__func__, smhandle, false);
+ if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || smhandle->nnz != nnz ||
+ smhandle->index != index) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle);
+ auto roc_index_type = detail::RocIndexEnumType<intType>::value;
+ auto roc_index_base = detail::get_roc_index_base(index);
+ auto roc_value_type = detail::RocEnumType<fpType>::value;
+ ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &smhandle->backend_handle, num_rows,
+ num_cols, nnz, row_ptr, col_ind, val, roc_index_type, roc_index_type,
+ roc_index_base, roc_value_type);
+ smhandle->num_rows = num_rows;
+ smhandle->num_cols = num_cols;
+ smhandle->nnz = nnz;
+ smhandle->index = index;
+ }
+ else {
+ ROCSPARSE_ERR_FUNC(rocsparse_csr_set_pointers, smhandle->backend_handle, row_ptr, col_ind,
+ val);
+ }
+ smhandle->row_container.set_usm_ptr(row_ptr);
+ smhandle->col_container.set_usm_ptr(col_ind);
+ smhandle->value_container.set_usm_ptr(val);
+}
+
+FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_CSR_MATRIX_FUNCS);
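+
+// Usage sketch (hypothetical values, USM variant): in CSR format row_ptr has
+// num_rows + 1 entries and row i spans [row_ptr[i], row_ptr[i+1]) within
+// col_ind and val. The same 3x3 identity as above uses row_ptr = {0, 1, 2, 3}:
+//
+//   matrix_handle_t smhandle = nullptr;
+//   init_csr_matrix(queue, &smhandle, 3, 3, 3, oneapi::mkl::index_base::zero,
+//                   row_ptr, col_ind, val);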
+
+sycl::event release_sparse_matrix(sycl::queue& queue, matrix_handle_t smhandle,
+ const std::vector<sycl::event>& dependencies) {
+ // Use dispatch_submit to ensure the backend's handle is kept alive as long as the buffers are used
+ auto functor = [=](sycl::interop_handle) {
+ ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle);
+ delete smhandle;
+ };
+ return detail::dispatch_submit(__func__, queue, dependencies, functor, smhandle);
+}
+
+// Matrix property
+bool set_matrix_property(sycl::queue&, matrix_handle_t smhandle, matrix_property property) {
+ // No equivalent in rocSPARSE
+ // Store the matrix property internally for future use
+ smhandle->set_matrix_property(property);
+ return false;
+}
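+
+// The `false` return value tells callers that the property was not forwarded
+// to the backend library, e.g. (hypothetical caller):
+//
+//   if (!set_matrix_property(queue, smhandle, matrix_property::sorted)) {
+//       // property is only recorded in the oneMKL handle
+//   }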
+
+} // namespace oneapi::mkl::sparse::rocsparse
diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp
new file mode 100644
index 000000000..feccec96f
--- /dev/null
+++ b/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp
@@ -0,0 +1,105 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*
+**************************************************************************/
+
+#ifndef _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_
+#define _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_
+
+#include <rocsparse/rocsparse.h>
+
+#include "sparse_blas/generic_container.hpp"
+
+namespace oneapi::mkl::sparse {
+
+// Complete the definition of incomplete types dense_vector_handle, dense_matrix_handle and matrix_handle.
+
+struct dense_vector_handle : public detail::generic_dense_vector_handle<rocsparse_dnvec_descr> {
+ template <typename T>
+ dense_vector_handle(rocsparse_dnvec_descr roc_descr, T* value_ptr, std::int64_t size)
+ : detail::generic_dense_vector_handle<rocsparse_dnvec_descr>(roc_descr, value_ptr,
+ size) {}
+
+ template <typename T>
+ dense_vector_handle(rocsparse_dnvec_descr roc_descr, const sycl::buffer<T, 1> value_buffer,
+ std::int64_t size)
+ : detail::generic_dense_vector_handle<rocsparse_dnvec_descr>(roc_descr, value_buffer,
+ size) {}
+};
+
+struct dense_matrix_handle : public detail::generic_dense_matrix_handle<rocsparse_dnmat_descr> {
+ template <typename T>
+ dense_matrix_handle(rocsparse_dnmat_descr roc_descr, T* value_ptr, std::int64_t num_rows,
+ std::int64_t num_cols, std::int64_t ld, layout dense_layout)
+ : detail::generic_dense_matrix_handle<rocsparse_dnmat_descr>(
+ roc_descr, value_ptr, num_rows, num_cols, ld, dense_layout) {}
+
+ template <typename T>
+ dense_matrix_handle(rocsparse_dnmat_descr roc_descr, const sycl::buffer<T, 1> value_buffer,
+ std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld,
+ layout dense_layout)
+ : detail::generic_dense_matrix_handle<rocsparse_dnmat_descr>(
+ roc_descr, value_buffer, num_rows, num_cols, ld, dense_layout) {}
+};
+
+struct matrix_handle : public detail::generic_sparse_handle<rocsparse_spmat_descr> {
+ // A matrix handle should only be used once per operation to be safe with the rocSPARSE backend.
+ // An operation can store information in the handle. See details in https://github.com/ROCm/rocSPARSE/issues/332.
+private:
+ bool used = false;
+
+public:
+ template <typename fpType, typename intType>