Skip to content

Commit

Permalink
Merge branch 'fea-device-buffer-multidevice' of github.com:harrism/rmm into fea-device-buffer-multidevice
Browse files Browse the repository at this point in the history
  • Loading branch information
harrism committed Nov 9, 2023
2 parents 8cadcd6 + 2381a79 commit 6a403e0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 37 deletions.
2 changes: 1 addition & 1 deletion include/rmm/device_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ class device_buffer {
{
if (bytes > 0) {
RMM_EXPECTS(nullptr != source, "Invalid copy from nullptr.");
RMM_EXPECTS(nullptr != _data, "Invalid copy from nullptr.");
RMM_EXPECTS(nullptr != _data, "Invalid copy to nullptr.");

RMM_CUDA_TRY(cudaMemcpyAsync(_data, source, bytes, cudaMemcpyDefault, stream().value()));
}
Expand Down
50 changes: 14 additions & 36 deletions include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@
#include <fmt/core.h>

#include <cstddef>
#include <functional>
#include <limits>
#include <map>
#include <mutex>
#include <set>
#include <thread>
#include <unordered_map>

namespace rmm::mr::detail {
Expand Down Expand Up @@ -259,23 +255,6 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
}

private:
/**
* @brief RAII wrapper for a CUDA event.
*/
struct event_wrapper {
event_wrapper()
{
RMM_ASSERT_CUDA_SUCCESS(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
}
~event_wrapper() { RMM_ASSERT_CUDA_SUCCESS(cudaEventDestroy(event)); }
cudaEvent_t event{};

event_wrapper(event_wrapper const&) = delete;
event_wrapper& operator=(event_wrapper const&) = delete;
event_wrapper(event_wrapper&&) noexcept = delete;
event_wrapper& operator=(event_wrapper&&) = delete;
};

/**
* @brief get a unique CUDA event (possibly new) associated with `stream`
*
Expand All @@ -289,17 +268,20 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
stream_event_pair get_event(cuda_stream_view stream)
{
if (stream.is_per_thread_default()) {
// Create a thread-local shared event wrapper for each device. Shared pointers in the thread
// and in each MR instance ensure the wrappers are destroyed only after all are finished
// with them.
thread_local std::vector<std::shared_ptr<event_wrapper>> events_tls(
rmm::get_num_cuda_devices());
auto event = [&, device_id = this->device_id_]() {
if (events_tls[device_id.value()]) { return events_tls[device_id.value()]->event; }

auto event = std::make_shared<event_wrapper>();
this->default_stream_events.insert(event);
return (events_tls[device_id.value()] = std::move(event))->event;
// Create a thread-local event for each device. These events are
// deliberately leaked since the destructor needs to call into
// the CUDA runtime and thread_local destructors (can) run below
// main: it is undefined behaviour to call into the CUDA
// runtime below main.
thread_local std::vector<cudaEvent_t> events_tls(rmm::get_num_cuda_devices());
auto event = [device_id = this->device_id_]() {
auto& e = events_tls[device_id.value()];
if (!e) {
// These events are deliberately not destructed and therefore live until
// program exit.
RMM_ASSERT_CUDA_SUCCESS(cudaEventCreateWithFlags(&e, cudaEventDisableTiming));
}
return e;
}();
return stream_event_pair{stream.value(), event};
}
Expand Down Expand Up @@ -505,10 +487,6 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
// bidirectional mapping between non-default streams and events
std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;

// shared pointers to events keeps the events alive as long as either the thread that created
// them or the MR that is using them exists.
std::set<std::shared_ptr<event_wrapper>> default_stream_events;

std::mutex mtx_; // mutex for thread-safe access

rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()};
Expand Down

0 comments on commit 6a403e0

Please sign in to comment.