Skip to content

Commit

Permalink
[BLAS] Simplify CublasScopedContextHandler (#609)
Browse files Browse the repository at this point in the history
  • Loading branch information
konradkusiak97 authored Nov 18, 2024
1 parent 8f4312e commit c0cef0c
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 183 deletions.
30 changes: 12 additions & 18 deletions src/blas/backends/cublas/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,30 @@
**************************************************************************/
#ifndef CUBLAS_HANDLE_HPP
#define CUBLAS_HANDLE_HPP
#include <atomic>
#include <unordered_map>
#include "cublas_helper.hpp"

namespace oneapi {
namespace mkl {
namespace blas {
namespace cublas {

template <typename T>
struct cublas_handle {
using handle_container_t = std::unordered_map<T, std::atomic<cublasHandle_t>*>;
using handle_container_t = std::unordered_map<CUdevice, cublasHandle_t>;
handle_container_t cublas_handle_mapper_{};
~cublas_handle() noexcept(false) {
CUresult err;
CUcontext original;
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original);
for (auto& handle_pair : cublas_handle_mapper_) {
cublasStatus_t err;
if (handle_pair.second != nullptr) {
auto handle = handle_pair.second->exchange(nullptr);
if (handle != nullptr) {
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
handle = nullptr;
}
else {
// if the handle is nullptr it means the handle was already
// destroyed by the ContextCallback and we're free to delete the
// atomic object.
delete handle_pair.second;
}

handle_pair.second = nullptr;
CUcontext desired;
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, handle_pair.first);
if (original != desired) {
// Sets the desired context as the active one for the thread in order to destroy its corresponding cublasHandle_t.
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired);
}
cublasStatus_t err;
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second);
}
cublas_handle_mapper_.clear();
}
Expand Down
113 changes: 17 additions & 96 deletions src/blas/backends/cublas/cublas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
*
**************************************************************************/
#include "cublas_scope_handle.hpp"
#if __has_include(<sycl/detail/common.hpp>)
#include <sycl/detail/common.hpp>
#else
#include <CL/sycl/detail/common.hpp>
#endif

namespace oneapi {
namespace mkl {
Expand All @@ -35,108 +30,34 @@ namespace cublas {
* takes place if no other element in the container has a key equivalent to
* the one being emplaced (keys in a map container are unique).
*/
#ifdef ONEMKL_PI_INTERFACE_REMOVED
thread_local cublas_handle<ur_context_handle_t> CublasScopedContextHandler::handle_helper =
cublas_handle<ur_context_handle_t>{};
#else
thread_local cublas_handle<pi_context> CublasScopedContextHandler::handle_helper =
cublas_handle<pi_context>{};
#endif
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

/**
 * Records the CUDA context that is current on the calling thread and, if it is not the
 * primary context of the interop handle's native device, makes that primary context current.
 *
 * @param queue  SYCL queue whose context is copied into a heap-allocated sycl::context.
 * @param ih     Interop handle used to obtain the native CUDA device.
 */
CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih)
        : ih(ih),
          needToRecover_(false) {
    placedContext_ = new sycl::context(queue.get_context());
    auto nativeDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
    CUresult status;
    CUcontext primaryCtx;
    CUDA_ERROR_FUNC(cuCtxGetCurrent, status, &original_);
    CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, status, &primaryCtx, nativeDevice);
    if (original_ == primaryCtx) {
        // The desired context is already active on this thread: nothing to switch or restore.
        return;
    }
    // Activate the desired (primary) context for this thread.
    CUDA_ERROR_FUNC(cuCtxSetCurrent, status, primaryCtx);
    // If no context was installed beforehand (the most common case) the primary context is
    // simply left active on the thread. This emulates the behaviour of the CUDA runtime API
    // and avoids costly context switches, so nothing has to be restored later. Only when a
    // different context was previously current does the destructor need to put it back.
    needToRecover_ = (original_ != nullptr);
}

/**
 * Restores the CUDA context that was current before construction (only when the constructor
 * flagged that a restore is needed) and then frees the heap-allocated sycl::context.
 */
CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) {
    if (needToRecover_) {
        // A different context was active before this object was built; put it back.
        CUresult status;
        CUDA_ERROR_FUNC(cuCtxSetCurrent, status, original_);
    }
    delete placedContext_;
}

void ContextCallback(void* userData) {
auto* ptr = static_cast<std::atomic<cublasHandle_t>*>(userData);
if (!ptr) {
return;
}
auto handle = ptr->exchange(nullptr);
if (handle != nullptr) {
cublasStatus_t err1;
CUBLAS_ERROR_FUNC(cublasDestroy, err1, handle);
handle = nullptr;
}
else {
// if the handle is nullptr it means the handle was already destroyed by
// the cublas_handle destructor and we're free to delete the atomic
// object.
delete ptr;
}
}
// Stores a reference to the interop handle; the constructor body is empty, so no CUDA
// context is queried or switched here.
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}

// Returns a cuBLAS handle bound to the CUDA stream of `queue`.
// The thread-local `handle_helper` map is searched first; on a hit the cached handle is
// rebound to the queue's stream only if its current stream differs. On a miss a new handle
// is created with cublasCreate, bound to the stream, inserted into the map and returned.
// NOTE(review): this span is a rendered diff hunk, so superseded statements (lookup keyed on
// the primary context, std::atomic-wrapped handles, the extended-deleter registration) appear
// next to their replacements (lookup keyed on CUdevice, plain cublasHandle_t values).
cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) {
auto cudaDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUresult cuErr;
CUcontext desired;
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice);
#ifdef ONEMKL_PI_INTERFACE_REMOVED
auto piPlacedContext_ = reinterpret_cast<ur_context_handle_t>(desired);
#else
auto piPlacedContext_ = reinterpret_cast<pi_context>(desired);
#endif
CUdevice device = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUstream streamId = get_stream(queue);
cublasStatus_t err;
auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_);

auto it = handle_helper.cublas_handle_mapper_.find(device);
if (it != handle_helper.cublas_handle_mapper_.end()) {
if (it->second == nullptr) {
handle_helper.cublas_handle_mapper_.erase(it);
}
else {
auto handle = it->second->load();
if (handle != nullptr) {
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
else {
handle_helper.cublas_handle_mapper_.erase(it);
}
cublasHandle_t nativeHandle = it->second;
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
// Rebind only when the cached handle is attached to a different stream.
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
}
return nativeHandle;
}

cublasHandle_t handle;

CUBLAS_ERROR_FUNC(cublasCreate, err, &handle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);

auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
std::make_pair(piPlacedContext_, new std::atomic<cublasHandle_t>(handle)));
// No cached handle for this device: create one and bind it to the queue's stream.
cublasHandle_t nativeHandle;
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);

sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback,
insert_iter.first->second);
auto insert_iter =
handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));

return handle;
return nativeHandle;
}

CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) {
Expand Down
35 changes: 2 additions & 33 deletions src/blas/backends/cublas/cublas_scope_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,9 @@
#else
#include <CL/sycl.hpp>
#endif
#if __has_include(<sycl/context.hpp>)
#if __SYCL_COMPILER_VERSION <= 20220930
#include <sycl/backend/cuda.hpp>
#endif
#include <sycl/context.hpp>
#else
#include <CL/sycl/backend/cuda.hpp>
#include <CL/sycl/context.hpp>
#endif

// After Plugin Interface removal in DPC++ ur.hpp is the new include
#if __has_include(<sycl/detail/ur.hpp>)
#include <sycl/detail/ur.hpp>
#ifndef ONEMKL_PI_INTERFACE_REMOVED
#define ONEMKL_PI_INTERFACE_REMOVED
#endif
#elif __has_include(<sycl/detail/pi.hpp>)
#include <sycl/detail/pi.hpp>
#else
#include <CL/sycl/detail/pi.hpp>
#endif

#include <atomic>
#include <memory>
#include <thread>
#include <unordered_map>
#include "cublas_helper.hpp"
#include "cublas_handle.hpp"

Expand Down Expand Up @@ -84,22 +61,14 @@ the handle must be destroyed when the context goes out of scope. This will bind
**/

class CublasScopedContextHandler {
CUcontext original_;
sycl::context* placedContext_;
bool needToRecover_;
sycl::interop_handle& ih;
#ifdef ONEMKL_PI_INTERFACE_REMOVED
static thread_local cublas_handle<ur_context_handle_t> handle_helper;
#else
static thread_local cublas_handle<pi_context> handle_helper;
#endif
static thread_local cublas_handle handle_helper;
CUstream get_stream(const sycl::queue& queue);
sycl::context get_context(const sycl::queue& queue);

public:
CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih);
CublasScopedContextHandler(sycl::interop_handle& ih);

~CublasScopedContextHandler() noexcept(false);
/**
 * @brief get_handle: creates the handle by implicitly imposing the advice
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
Expand Down
32 changes: 11 additions & 21 deletions src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,33 @@ namespace mkl {
namespace blas {
namespace cublas {

thread_local cublas_handle<int> CublasScopedContextHandler::handle_helper = cublas_handle<int>{};
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

// Initialises the `interop_h` member from `ih`; the `queue` argument is not used by this
// constructor and the body is empty.
CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih)
: interop_h(ih) {}

// Returns a cuBLAS handle bound to the CUDA stream of `queue`.
// The thread-local `handle_helper` map is searched using the native CUDA device reported by
// `interop_h`; a cached handle is rebound to the queue's stream only when its current stream
// differs. Otherwise a new handle is created, bound to the stream, cached and returned.
// NOTE(review): this span is a rendered diff hunk, so superseded statements (int device key,
// std::atomic-wrapped handles) appear next to their replacements (CUdevice key, plain
// cublasHandle_t values).
cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) {
sycl::device device = queue.get_device();
int current_device = interop_h.get_native_device<sycl::backend::cuda>();
CUdevice current_device = interop_h.get_native_device<sycl::backend::cuda>();
CUstream streamId = get_stream(queue);
cublasStatus_t err;
auto it = handle_helper.cublas_handle_mapper_.find(current_device);
if (it != handle_helper.cublas_handle_mapper_.end()) {
if (it->second == nullptr) {
handle_helper.cublas_handle_mapper_.erase(it);
}
else {
auto handle = it->second->load();
if (handle != nullptr) {
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
else {
handle_helper.cublas_handle_mapper_.erase(it);
}
cublasHandle_t handle = it->second;
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
// Rebind only when the cached handle is attached to a different stream.
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
// No cached handle for this device: create one and bind it to the queue's stream.
cublasHandle_t handle;

CUBLAS_ERROR_FUNC(cublasCreate, err, &handle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);

auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
std::make_pair(current_device, new std::atomic<cublasHandle_t>(handle)));
auto insert_iter =
handle_helper.cublas_handle_mapper_.insert(std::make_pair(current_device, handle));
return handle;
}

Expand All @@ -71,4 +61,4 @@ CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) {
} // namespace cublas
} // namespace blas
} // namespace mkl
} // namespace oneapi
} // namespace oneapi
3 changes: 1 addition & 2 deletions src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#endif
#include <memory>
#include <thread>
#include <unordered_map>
#include "cublas_helper.hpp"
#include "cublas_handle.hpp"
namespace oneapi {
Expand Down Expand Up @@ -60,7 +59,7 @@ the handle must be destroyed when the context goes out of scope. This will bind

class CublasScopedContextHandler {
sycl::interop_handle interop_h;
static thread_local cublas_handle<int> handle_helper;
static thread_local cublas_handle handle_helper;
sycl::context get_context(const sycl::queue& queue);
CUstream get_stream(const sycl::queue& queue);

Expand Down
14 changes: 1 addition & 13 deletions src/blas/backends/cublas/cublas_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,6 @@
#else
#include "cublas_scope_handle_hipsycl.hpp"

// After Plugin Interface removal in DPC++ ur.hpp is the new include
#if __has_include(<sycl/detail/ur.hpp>)
#include <sycl/detail/ur.hpp>
#ifndef ONEMKL_PI_INTERFACE_REMOVED
#define ONEMKL_PI_INTERFACE_REMOVED
#endif
#elif __has_include(<sycl/detail/pi.hpp>)
#include <sycl/detail/pi.hpp>
#else
#include <CL/sycl/detail/pi.hpp>
#endif

namespace sycl {
using interop_handler = sycl::interop_handle;
}
Expand All @@ -72,7 +60,7 @@ static inline void host_task_internal(H& cgh, sycl::queue queue, F f) {
#else
cgh.host_task([f, queue](sycl::interop_handle ih) {
#endif
auto sc = CublasScopedContextHandler(queue, ih);
auto sc = CublasScopedContextHandler(ih);
f(sc);
});
}
Expand Down

0 comments on commit c0cef0c

Please sign in to comment.