From 52e0d7331cb533955f479d82e4656253eaa9ef6f Mon Sep 17 00:00:00 2001 From: Michael Schellenberger Costa Date: Thu, 21 Mar 2024 16:15:41 -0700 Subject: [PATCH] Replace usages of raw `get_upstream` with `get_upstream_resource()` (#2207) We want to get rid of raw memory resources so move to the new interface instead Authors: - Michael Schellenberger Costa (https://github.com/miscco) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2207 --- cpp/test/core/device_resources_manager.cpp | 16 ++++++++-------- cpp/test/core/handle.cpp | 8 +++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cpp/test/core/device_resources_manager.cpp b/cpp/test/core/device_resources_manager.cpp index c7c9e175ea..b9b8996a09 100644 --- a/cpp/test/core/device_resources_manager.cpp +++ b/cpp/test/core/device_resources_manager.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -114,17 +115,16 @@ TEST(DeviceResourcesManager, ObeysSetters) auto* mr = dynamic_cast*>( rmm::mr::get_current_device_resource()); - auto* workspace_mr = - dynamic_cast*>( - dynamic_cast*>( - res.get_workspace_resource()) - ->get_upstream()); + rmm::device_async_resource_ref workspace_mr = + dynamic_cast*>( + res.get_workspace_resource()) + ->get_upstream_resource(); if (upstream_mrs[i % devices.size()] != nullptr) { // Expect that the current memory resource is a pool memory resource as requested EXPECT_NE(mr, nullptr); - // Expect that the upstream workspace memory resource is a pool memory - // resource as requested - EXPECT_NE(workspace_mr, nullptr); + + // We cannot easily check the type of a resource_ref + (void)workspace_mr; } { diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 0b0b4b54ab..be18b0d5b4 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -281,7 +282,8 @@ TEST(Raft, WorkspaceResource) raft::handle_t handle; // The returned resource is always a limiting adaptor - auto* orig_mr = resource::get_workspace_resource(handle)->get_upstream(); + rmm::device_async_resource_ref orig_mr{ + resource::get_workspace_resource(handle)->get_upstream_resource()}; // Let's create a pooled resource auto pool_mr = std::shared_ptr{new rmm::mr::pool_memory_resource( @@ -295,8 +297,8 @@ TEST(Raft, WorkspaceResource) auto new_mr = resource::get_workspace_resource(handle); // By this point, the orig_mr likely points to a non-existent resource; don't dereference! - ASSERT_NE(orig_mr, new_mr); - ASSERT_EQ(pool_mr.get(), new_mr->get_upstream()); + ASSERT_NE(orig_mr, rmm::device_async_resource_ref{new_mr}); + ASSERT_EQ(rmm::device_async_resource_ref{pool_mr.get()}, new_mr->get_upstream_resource()); // We can safely reset pool_mr, because the shared_ptr to the pool memory stays in the resource pool_mr.reset();