Skip to content

Commit

Permalink
[rocfft][cufft] DFT update host task to use native command (oneapi-sr…
Browse files Browse the repository at this point in the history
…c#578)

Signed-off-by: JackAKirk <[email protected]>
Co-authored-by: Hugh Bird <[email protected]>
Co-authored-by: Rafal Bielski <[email protected]>
Co-authored-by: Romain Biessy <[email protected]>
  • Loading branch information
4 people authored Oct 14, 2024
1 parent f973570 commit 058ee95
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 69 deletions.
9 changes: 5 additions & 4 deletions src/dft/backends/cufft/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "oneapi/mkl/dft/types.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"

#include <cufft.h>

Expand Down Expand Up @@ -71,7 +72,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
Expand Down Expand Up @@ -117,7 +118,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -171,7 +172,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
Expand Down Expand Up @@ -217,7 +218,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
Expand Down
12 changes: 8 additions & 4 deletions src/dft/backends/cufft/execute_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef _ONEMKL_DFT_SRC_CUFFT_EXECUTE_HPP_
#define _ONEMKL_DFT_SRC_CUFFT_EXECUTE_HPP_
#ifndef _ONEMKL_DFT_SRC_EXECUTE_HELPER_CUFFT_HPP_
#define _ONEMKL_DFT_SRC_EXECUTE_HELPER_CUFFT_HPP_

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
Expand Down Expand Up @@ -125,12 +125,16 @@ void cufft_execute(const std::string &func, CUstream stream, cufftHandle plan, v
}
}
}

#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
// If not using the enqueue native extension, the host task must wait on the
// asynchronous operation to complete. Otherwise it reports the operation
// as complete early.
auto result = cuStreamSynchronize(stream);
if (result != CUDA_SUCCESS) {
throw oneapi::mkl::exception("dft/backends/cufft", func,
"cuStreamSynchronize returned " + std::to_string(result));
}
#endif
}

inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, cufftHandle plan) {
Expand All @@ -145,4 +149,4 @@ inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, c

} // namespace oneapi::mkl::dft::cufft::detail

#endif
#endif // _ONEMKL_DFT_SRC_EXECUTE_HELPER_CUFFT_HPP_
9 changes: 5 additions & 4 deletions src/dft/backends/cufft/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "oneapi/mkl/dft/types.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"

#include <cufft.h>

Expand Down Expand Up @@ -74,7 +75,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
Expand Down Expand Up @@ -119,7 +120,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer<fwd<descr
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -173,7 +174,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
Expand Down Expand Up @@ -219,7 +220,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
Expand Down
42 changes: 18 additions & 24 deletions src/dft/backends/rocfft/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "oneapi/mkl/dft/descriptor.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"
#include "rocfft_handle.hpp"

#include <rocfft.h>
Expand Down Expand Up @@ -78,14 +79,13 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

auto inout_native = reinterpret_cast<void *>(
reinterpret_cast<fwd<descriptor_type> *>(detail::native_mem(ih, inout_acc)) +
offsets[0]);
detail::execute_checked(func_name, plan, &inout_native, nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &inout_native, nullptr, info);
});
});
}
Expand Down Expand Up @@ -113,7 +113,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_im_acc = inout_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{
Expand All @@ -124,8 +124,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
detail::native_mem(ih, inout_im_acc)) +
offsets[0])
};
detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info);
});
});
}
Expand All @@ -148,7 +147,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand All @@ -158,8 +157,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_native = reinterpret_cast<void *>(
reinterpret_cast<fwd<descriptor_type> *>(detail::native_mem(ih, out_acc)) +
offsets[1]);
detail::execute_checked(func_name, plan, &in_native, &out_native, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &in_native, &out_native, info);
});
});
}
Expand All @@ -184,7 +182,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_im_acc = out_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand All @@ -204,8 +202,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
detail::native_mem(ih, out_im_acc)) +
offsets[1])
};
detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info);
});
});
}
Expand Down Expand Up @@ -239,12 +236,11 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

void *inout_ptr = inout;
detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &inout_ptr, nullptr, info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand Down Expand Up @@ -273,12 +269,12 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{ inout_re + offsets[0], inout_im + offsets[0] };
detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info);

});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand All @@ -305,14 +301,13 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

void *in_ptr = in;
void *out_ptr = out;
detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &in_ptr, &out_ptr, info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand All @@ -336,15 +331,14 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name =
"compute_backward(desc, in_re, in_im, out_re, out_im, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> in_native{ in_re + offsets[0], in_im + offsets[0] };
std::array<void *, 2> out_native{ out_re + offsets[1], out_im + offsets[1] };
detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand Down
24 changes: 15 additions & 9 deletions src/dft/backends/rocfft/execute_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_
#define _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_
#ifndef _ONEMKL_DFT_SRC_EXECUTE_HELPER_ROCFFT_HPP_
#define _ONEMKL_DFT_SRC_EXECUTE_HELPER_ROCFFT_HPP_

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
Expand Down Expand Up @@ -76,22 +76,28 @@ inline hipStream_t setup_stream(const std::string &func, sycl::interop_handle &i
}

inline void sync_checked(const std::string &func, hipStream_t stream) {
auto result = hipStreamSynchronize(stream);
if (result != hipSuccess) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"hipStreamSynchronize returned " + std::to_string(result));
}
auto result = hipStreamSynchronize(stream);
if (result != hipSuccess) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"hipStreamSynchronize returned " + std::to_string(result));
}
}

inline void execute_checked(const std::string &func, const rocfft_plan plan, void *in_buffer[],
inline void execute_checked(const std::string &func, hipStream_t stream, const rocfft_plan plan, void *in_buffer[],
void *out_buffer[], rocfft_execution_info info) {
auto result = rocfft_execute(plan, in_buffer, out_buffer, info);
if (result != rocfft_status_success) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"rocfft_execute returned " + std::to_string(result));
}
#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
// If not using the enqueue native extension, the host task must wait on the
// asynchronous operation to complete. Otherwise it reports the operation
// as complete early.
sync_checked(func, stream);
#endif
}

} // namespace oneapi::mkl::dft::rocfft::detail

#endif
#endif // _ONEMKL_DFT_SRC_EXECUTE_HELPER_ROCFFT_HPP_
Loading

0 comments on commit 058ee95

Please sign in to comment.