From 42a2d4a721469803882957cd458b885294e2bc59 Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Tue, 10 Dec 2024 14:24:41 -0800 Subject: [PATCH 1/3] fix style --- cpp/src/c_api/neighbor_sampling.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index 37982eab82..2cc9646309 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -880,7 +880,6 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { handle_.get_stream()); std::optional> start_vertex_labels{std::nullopt}; - std::optional> local_label_to_comm_rank{std::nullopt}; std::optional> label_to_comm_rank{ std::nullopt}; // global after allgatherv @@ -932,12 +931,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { handle_.get_stream(), raft::device_span{unique_labels.data(), unique_labels.size()}); - (*local_label_to_comm_rank).resize(num_unique_labels, handle_.get_stream()); + rmm::device_uvector local_label_to_comm_rank(num_unique_labels, + handle_.get_stream()); cugraph::detail::scalar_fill( handle_.get_stream(), - (*local_label_to_comm_rank).begin(), // This should be rename to rank - (*local_label_to_comm_rank).size(), + local_label_to_comm_rank.begin(), // This should be rename to rank + local_label_to_comm_rank.size(), label_t{handle_.get_comms().get_rank()}); // Perform allgather to get global_label_to_comm_rank_d_vector @@ -948,11 +948,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { std::exclusive_scan( recvcounts.begin(), recvcounts.end(), displacements.begin(), size_t{0}); - (*label_to_comm_rank) - .resize(displacements.back() + recvcounts.back(), handle_.get_stream()); + rmm::device_uvector tmp_label_to_comm_rank( + displacements.back() + recvcounts.back(), handle_.get_stream()); + + label_to_comm_rank = std::move(tmp_label_to_comm_rank); cugraph::device_allgatherv(handle_.get_comms(), - (*local_label_to_comm_rank).begin(), + local_label_to_comm_rank.begin(), (*label_to_comm_rank).begin(), recvcounts, displacements, From 949f75d5b676be67920fc446f0a1b28529d7867d Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Mon, 6 Jan 2025 09:56:12 -0800 Subject: [PATCH 2/3] remove tmp assignment --- cpp/src/c_api/neighbor_sampling.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index 2cc9646309..7193032317 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -948,11 +948,9 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { std::exclusive_scan( recvcounts.begin(), recvcounts.end(), displacements.begin(), size_t{0}); - rmm::device_uvector tmp_label_to_comm_rank( + label_to_comm_rank = rmm::device_uvector( displacements.back() + recvcounts.back(), handle_.get_stream()); - label_to_comm_rank = std::move(tmp_label_to_comm_rank); - cugraph::device_allgatherv(handle_.get_comms(), local_label_to_comm_rank.begin(), (*label_to_comm_rank).begin(), From a40e6f52fccc9fbd5beb88353c3f253c001577f2 Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Mon, 6 Jan 2025 11:09:32 -0800 Subject: [PATCH 3/3] update copyright --- cpp/src/c_api/neighbor_sampling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index 7193032317..9c6c1f0f02 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.