diff --git a/src/dft/backends/cufft/backward.cpp b/src/dft/backends/cufft/backward.cpp index 693ad4d1b..80e475991 100644 --- a/src/dft/backends/cufft/backward.cpp +++ b/src/dft/backends/cufft/backward.cpp @@ -30,6 +30,7 @@ #include "oneapi/mkl/dft/types.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include @@ -71,7 +72,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto inout_native = reinterpret_cast *>( @@ -117,7 +118,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto out_acc = out.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto in_native = reinterpret_cast( @@ -171,7 +172,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( @@ -217,7 +218,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( diff --git a/src/dft/backends/cufft/execute_helper.hpp b/src/dft/backends/cufft/execute_helper.hpp index 644cf7148..bbe32c146 100644 --- a/src/dft/backends/cufft/execute_helper.hpp +++ b/src/dft/backends/cufft/execute_helper.hpp @@ -147,27 +147,6 @@ inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, c return stream; } - -/** Wrap interop API to launch interop host task. - * - * @tparam HandlerT The command group handler type - * @tparam FnT The body of the enqueued task - * - * Either uses host task interop API, or enqueue native command extension. - * This extension avoids host synchronization after - * the CUDA call is complete. - */ -template -static inline void cufft_enqueue_task(HandlerT&& cgh, FnT&& f) { -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ -#else - cgh.host_task([=](sycl::interop_handle ih){ -#endif - f(std::move(ih)); - }); -} - } // namespace oneapi::mkl::dft::cufft::detail #endif diff --git a/src/dft/backends/cufft/forward.cpp b/src/dft/backends/cufft/forward.cpp index bdbda2cb5..7cf73976d 100644 --- a/src/dft/backends/cufft/forward.cpp +++ b/src/dft/backends/cufft/forward.cpp @@ -31,6 +31,7 @@ #include "oneapi/mkl/dft/types.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include @@ -74,7 +75,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto inout_native = reinterpret_cast *>( @@ -119,7 +120,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto in_native = reinterpret_cast( @@ -173,7 +174,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( @@ -219,7 +220,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( diff --git a/src/dft/backends/rocfft/backward.cpp b/src/dft/backends/rocfft/backward.cpp index bdb1c9638..e76437ee2 100644 --- a/src/dft/backends/rocfft/backward.cpp +++ b/src/dft/backends/rocfft/backward.cpp @@ -29,6 +29,7 @@ #include "oneapi/mkl/dft/descriptor.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include "rocfft_handle.hpp" #include @@ -78,7 +79,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); auto inout_native = reinterpret_cast( @@ -112,7 +113,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto inout_im_acc = inout_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ @@ -146,7 +147,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto out_acc = out.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in, out)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -181,7 +182,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto out_im_acc = out_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -235,7 +236,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); void *inout_ptr = inout; @@ -268,7 +269,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ inout_re + offsets[0], inout_im + offsets[0] }; @@ -300,7 +301,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in, out, deps)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -330,7 +331,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im, deps)"; auto stream = detail::setup_stream(func_name, ih, info); diff --git a/src/dft/backends/rocfft/execute_helper.hpp b/src/dft/backends/rocfft/execute_helper.hpp index 49c499637..a182546b5 100644 --- a/src/dft/backends/rocfft/execute_helper.hpp +++ b/src/dft/backends/rocfft/execute_helper.hpp @@ -98,26 +98,6 @@ inline void execute_checked(const std::string &func, hipStream_t stream, const r #endif } -/** Wrap interop API to launch interop host task. - * - * @tparam HandlerT The command group handler type - * @tparam FnT The body of the enqueued task - * - * Either uses host task interop API, or enqueue native command extension. - * This extension avoids host synchronization after - * the CUDA call is complete. - */ -template -static inline void rocfft_enqueue_task(HandlerT&& cgh, FnT&& f) { -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ -#else - cgh.host_task([=](sycl::interop_handle ih){ -#endif - f(std::move(ih)); - }); -} - } // namespace oneapi::mkl::dft::rocfft::detail #endif diff --git a/src/dft/backends/rocfft/forward.cpp b/src/dft/backends/rocfft/forward.cpp index daacc685d..d9a576720 100644 --- a/src/dft/backends/rocfft/forward.cpp +++ b/src/dft/backends/rocfft/forward.cpp @@ -30,6 +30,7 @@ #include "oneapi/mkl/dft/descriptor.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include "rocfft_handle.hpp" #include @@ -81,7 +82,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); auto inout_native = reinterpret_cast( @@ -115,7 +116,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto inout_im_acc = inout_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ @@ -148,7 +149,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in, out)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -183,7 +184,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto out_im_acc = out_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -237,7 +238,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); void *inout_ptr = inout; @@ -269,7 +270,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ inout_re + offsets[0], inout_im + offsets[0] }; @@ -300,7 +301,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in, out, deps)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -330,7 +331,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im, deps)"; auto stream = detail::setup_stream(func_name, ih, info); diff --git a/src/dft/execute_helper_generic.hpp b/src/dft/execute_helper_generic.hpp new file mode 100644 index 000000000..519f6fda6 --- /dev/null +++ b/src/dft/execute_helper_generic.hpp @@ -0,0 +1,53 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_CUFFT_EXECUTE_GENERIC_HPP_ +#define _ONEMKL_DFT_SRC_CUFFT_EXECUTE_GENERIC_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +namespace oneapi::mkl::dft::detail { + +/** Wrap interop API to launch interop host task. + * + * @tparam HandlerT The command group handler type + * @tparam FnT The body of the enqueued task + * + * Either uses host task interop API, or enqueue native command extension. + * This extension avoids host synchronization after + * the native call is complete. + */ +template +static inline void fft_enqueue_task(HandlerT&& cgh, FnT&& f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ +#else + cgh.host_task([=](sycl::interop_handle ih){ +#endif + f(std::move(ih)); + }); +} + +} // namespace oneapi::mkl::dft::detail + +#endif