Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace all internal usage of get_upstream with get_upstream_resource #1491

Merged
merged 5 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/rmm/mr/device/aligned_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ class aligned_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<aligned_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr && upstream_->is_equal(*cast->get_upstream()) &&
if (cast == nullptr) { return false; }
return get_upstream_resource() == cast->get_upstream_resource() &&
alignment_ == cast->alignment_ && alignment_threshold_ == cast->alignment_threshold_;
}

Expand Down
13 changes: 6 additions & 7 deletions include/rmm/mr/device/binning_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ class binning_memory_resource final : public device_memory_resource {
* Chooses a memory_resource that allocates the smallest blocks at least as large as `bytes`.
*
* @param bytes Requested allocation size in bytes
* @return rmm::mr::device_memory_resource& memory_resource that can allocate the requested size.
* @return Get the resource reference for the requested size.
*/
device_memory_resource* get_resource(std::size_t bytes)
rmm::device_async_resource_ref get_resource_ref(std::size_t bytes)
{
auto iter = resource_bins_.lower_bound(bytes);
return (iter != resource_bins_.cend()) ? iter->second
: static_cast<device_memory_resource*>(get_upstream());
return (iter != resource_bins_.cend()) ? rmm::device_async_resource_ref{iter->second}
: get_upstream_resource();
}

/**
Expand All @@ -170,7 +170,7 @@ class binning_memory_resource final : public device_memory_resource {
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
if (bytes <= 0) { return nullptr; }
return get_resource(bytes)->allocate(bytes, stream);
return get_resource_ref(bytes).allocate_async(bytes, stream);
}

/**
Expand All @@ -183,8 +183,7 @@ class binning_memory_resource final : public device_memory_resource {
*/
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override
{
auto res = get_resource(bytes);
if (res != nullptr) { res->deallocate(ptr, bytes, stream); }
get_resource_ref(bytes).deallocate_async(ptr, bytes, stream);
}

Upstream* upstream_mr_; // The upstream memory_resource from which to allocate blocks.
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/failure_callback_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<failure_callback_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

Upstream* upstream_; // the upstream resource used for satisfying allocation requests
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/fixed_size_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class fixed_size_memory_resource
*/
free_list blocks_from_upstream(cuda_stream_view stream)
{
void* ptr = get_upstream()->allocate(upstream_chunk_size_, stream);
void* ptr = get_upstream_resource().allocate_async(upstream_chunk_size_, stream);
block_type block{ptr};
upstream_blocks_.push_back(block);

Expand Down Expand Up @@ -211,7 +211,7 @@ class fixed_size_memory_resource
lock_guard lock(this->get_mutex());

for (auto block : upstream_blocks_) {
get_upstream()->deallocate(block.pointer(), upstream_chunk_size_);
get_upstream_resource().deallocate(block.pointer(), upstream_chunk_size_);
}
upstream_blocks_.clear();
}
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/limiting_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class limiting_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<limiting_resource_adaptor<Upstream> const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

// maximum bytes this allocator is allowed to allocate.
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/logging_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ class logging_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<logging_resource_adaptor<Upstream> const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

// make_logging_adaptor needs access to private get_default_filename
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/pool_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class pool_memory_resource final
if (size == 0) { return {}; }

try {
void* ptr = get_upstream()->allocate_async(size, stream);
void* ptr = get_upstream_resource().allocate_async(size, stream);
return std::optional<block_type>{
*upstream_blocks_.emplace(static_cast<char*>(ptr), size, true).first};
} catch (std::exception const& e) {
Expand Down Expand Up @@ -570,7 +570,7 @@ class pool_memory_resource final
lock_guard lock(this->get_mutex());

for (auto block : upstream_blocks_) {
get_upstream()->deallocate(block.pointer(), block.size());
get_upstream_resource().deallocate(block.pointer(), block.size());
}
upstream_blocks_.clear();
#ifdef RMM_POOL_TRACK_ALLOCATIONS
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/statistics_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ class statistics_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<statistics_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

counter bytes_; // peak, current and total allocated bytes
Expand Down
8 changes: 3 additions & 5 deletions include/rmm/mr/device/thread_safe_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,9 @@ class thread_safe_resource_adaptor final : public device_memory_resource {
bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto thread_safe_other = dynamic_cast<thread_safe_resource_adaptor<Upstream> const*>(&other);
if (thread_safe_other != nullptr) {
return upstream_->is_equal(*thread_safe_other->get_upstream());
}
return upstream_->is_equal(other);
auto cast = dynamic_cast<thread_safe_resource_adaptor<Upstream> const*>(&other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

std::mutex mutable mtx; // mutex for thread safe access to upstream
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/tracking_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ class tracking_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<tracking_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

bool capture_stacks_; // whether or not to capture call stacks
Expand Down
4 changes: 2 additions & 2 deletions tests/device_check_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class device_check_resource_adaptor final : public rmm::mr::device_memory_resour
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<device_check_resource_adaptor const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

rmm::cuda_device_id device_id;
Expand Down
9 changes: 0 additions & 9 deletions tests/mr/device/adaptor_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ TYPED_TEST(AdaptorTest, Equality)
}
}

TYPED_TEST(AdaptorTest, GetUpstream)
{
if constexpr (std::is_same_v<TypeParam, owning_wrapper>) {
EXPECT_TRUE(this->mr->wrapped().get_upstream()->is_equal(this->cuda));
} else {
EXPECT_TRUE(this->mr->get_upstream()->is_equal(this->cuda));
}
}

TYPED_TEST(AdaptorTest, GetUpstreamResource)
{
rmm::device_async_resource_ref expected{this->cuda};
Expand Down
5 changes: 3 additions & 2 deletions tests/mr/device/statistics_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ TEST(StatisticsTest, PeakAllocations)

TEST(StatisticsTest, MultiTracking)
{
statistics_adaptor mr{rmm::mr::get_current_device_resource()};
auto* orig_device_resource = rmm::mr::get_current_device_resource();
statistics_adaptor mr{orig_device_resource};
rmm::mr::set_current_device_resource(&mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
Expand Down Expand Up @@ -171,7 +172,7 @@ TEST(StatisticsTest, MultiTracking)
EXPECT_EQ(inner_mr.get_allocations_counter().peak, 5);

// Reset the current device resource
rmm::mr::set_current_device_resource(mr.get_upstream());
rmm::mr::set_current_device_resource(orig_device_resource);
}

TEST(StatisticsTest, NegativeInnerTracking)
Expand Down
5 changes: 3 additions & 2 deletions tests/mr/device/tracking_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ TEST(TrackingTest, AllocationsLeftWithoutStacks)

TEST(TrackingTest, MultiTracking)
{
tracking_adaptor mr{rmm::mr::get_current_device_resource(), true};
auto* orig_device_resource = rmm::mr::get_current_device_resource();
tracking_adaptor mr{orig_device_resource, true};
rmm::mr::set_current_device_resource(&mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
Expand Down Expand Up @@ -140,7 +141,7 @@ TEST(TrackingTest, MultiTracking)
EXPECT_EQ(inner_mr.get_allocated_bytes(), 0);

// Reset the current device resource
rmm::mr::set_current_device_resource(mr.get_upstream());
rmm::mr::set_current_device_resource(orig_device_resource);
}

TEST(TrackingTest, NegativeInnerTracking)
Expand Down
Loading