From 6223fd2ff33749eb843e40acf05b2d92b3a61be4 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Thu, 26 Sep 2024 20:49:14 +0900 Subject: [PATCH 01/12] [Improved Multi-CTA algo] Address low recall issue of multi-CTA algo when the number of results is large Fix some issues Fix lower recall issue with new multi-cta algo Removing redundant code and changing some parameters Update cpp/src/neighbors/detail/cagra/search_plan.cuh Co-authored-by: Tamas Bela Feher Remove an unnecessary line and satisfy clang-format --- .../neighbors/detail/cagra/cagra_search.cuh | 2 +- .../neighbors/detail/cagra/device_common.hpp | 37 ++- cpp/src/neighbors/detail/cagra/factory.cuh | 9 +- cpp/src/neighbors/detail/cagra/hashmap.hpp | 87 +++++-- .../detail/cagra/search_multi_cta.cuh | 16 +- .../detail/cagra/search_multi_cta_inst.cuh | 1 + .../cagra/search_multi_cta_kernel-inl.cuh | 213 ++++++++++-------- .../detail/cagra/search_multi_cta_kernel.cuh | 5 +- .../detail/cagra/search_multi_kernel.cuh | 3 +- .../neighbors/detail/cagra/search_plan.cuh | 170 ++++++++------ .../detail/cagra/search_single_cta.cuh | 3 +- .../cagra/search_single_cta_kernel-inl.cuh | 6 +- 12 files changed, 347 insertions(+), 205 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 5778d85a6..b4f701819 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res, using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; std::unique_ptr> plan = factory::create( - res, params, dataset_desc, queries.extent(1), graph.extent(1), topk); + res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk); plan->check(topk); diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 7ec3d4d9e..9dcc5123b 100644 --- 
a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( const IndexT* __restrict__ seed_ptr, // [num_seeds] const uint32_t num_seeds, IndexT* __restrict__ visited_hash_ptr, - const uint32_t hash_bitlen, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hash_ptr, + const uint32_t traversed_hash_bitlen, const uint32_t block_id = 0, const uint32_t num_blocks = 1) { @@ -145,14 +147,21 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); if (valid_i && lane_id == 0) { - if (best_index_team_local != raft::upper_bound() && - hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { - result_distances_ptr[i] = best_norm2_team_local; - result_indices_ptr[i] = best_index_team_local; - } else { - result_distances_ptr[i] = raft::upper_bound(); - result_indices_ptr[i] = raft::upper_bound(); + if (best_index_team_local != raft::upper_bound()) { + if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } else if ((traversed_hash_ptr != nullptr) && + hashmap::search( + traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { + // Deactivate this entry as it has been already used by others. 
+ best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } } + result_distances_ptr[i] = best_norm2_team_local; + result_indices_ptr[i] = best_index_team_local; } } } @@ -168,7 +177,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const uint32_t knn_k, // hashmap IndexT* __restrict__ visited_hashmap_ptr, - const uint32_t hash_bitlen, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hashmap_ptr, + const uint32_t traversed_hash_bitlen, const IndexT* __restrict__ parent_indices, const IndexT* __restrict__ internal_topk_list, const uint32_t search_width) @@ -186,7 +197,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; } if (child_id != invalid_index) { - if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { + if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + child_id = invalid_index; + } else if ((traversed_hashmap_ptr != nullptr) && + hashmap::search( + traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { + // Deactivate this entry as this has been already used by others. 
child_id = invalid_index; } } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index e6e7ff64f..064f880ad 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -40,10 +40,11 @@ class factory { search_params const& params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) { - search_plan_impl_base plan(params, dim, graph_degree, topk); + search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk); return dispatch_kernel(res, plan, dataset_desc); } @@ -56,15 +57,15 @@ class factory { if (plan.algo == search_algo::SINGLE_CTA) { return std::make_unique< single_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else if (plan.algo == search_algo::MULTI_CTA) { return std::make_unique< multi_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else { return std::make_unique< multi_kernel_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } } }; diff --git a/cpp/src/neighbors/detail/cagra/hashmap.hpp b/cpp/src/neighbors/detail/cagra/hashmap.hpp index 2c62dda90..6dbdd5a8a 100644 --- a/cpp/src/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/src/neighbors/detail/cagra/hashmap.hpp @@ -23,6 +23,8 @@ #include +#define HASHMAP_LINEAR_PROBING + // #pragma GCC diagnostic push // #pragma GCC diagnostic ignored // #pragma GCC diagnostic pop @@ -42,7 +44,7 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table, } } -template +template RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT 
key) @@ -50,7 +52,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, // Open addressing is used for collision resolution const uint32_t size = get_size(bitlen); const uint32_t bit_mask = size - 1; -#if 1 +#ifdef HASHMAP_LINEAR_PROBING // Linear probing IdxT index = (key ^ (key >> bitlen)) & bit_mask; constexpr uint32_t stride = 1; @@ -59,32 +61,91 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, uint32_t index = key & bit_mask; const uint32_t stride = (key >> bitlen) * 2 + 1; #endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { - const IdxT old = atomicCAS(&table[index], ~static_cast(0), key); - if (old == ~static_cast(0)) { + const IdxT old = atomicCAS(&table[index], hashval_empty, key); + if (old == hashval_empty) { return 1; } else if (old == key) { return 0; + } else if (SUPPORT_REMOVE) { + // Checks if this key has been removed before. + const uint32_t old = atomicCAS(&table[index], removed_key, key); + if (old == removed_key) { + return 1; + } else if (old == key) { + return 0; + } } index = (index + stride) & bit_mask; } return 0; } -template -RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, - const uint32_t bitlen, - const IdxT key) +template +RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key) { - IdxT ret = 0; - if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } - for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { - ret |= __shfl_xor_sync(0xffffffff, ret, offset); + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; +#ifdef HASHMAP_LINEAR_PROBING + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; +#else + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; +#endif + constexpr IdxT hashval_empty = 
~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + for (unsigned i = 0; i < size; i++) { + const IdxT val = table[index]; + if (val == key) { + return 1; + } else if (val == hashval_empty) { + return 0; + } else if (SUPPORT_REMOVE) { + // Check if this key has been removed. + if (val == removed_key) { + return 0; + } + } + index = (index + stride) & bit_mask; } - return ret; + return 0; } template +RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key) +{ + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; +#ifdef HASHMAP_LINEAR_PROBING + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; +#else + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; +#endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + for (unsigned i = 0; i < size; i++) { + // To remove a key, set the MSB to 1. 
+ const uint32_t old = atomicCAS(&table[index], key, removed_key); + if (old == key) { + return 1; + } else if (old == hashval_empty) { + return 0; + } + index = (index + stride) & bit_mask; + } + return 0; +} + +template RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(unsigned team_size, IdxT* const table, const uint32_t bitlen, const IdxT key) { diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index ecfd856f1..8d425ca67 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -102,24 +102,24 @@ struct search : public search_plan_impl& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk), + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), intermediate_indices(res), intermediate_distances(res), topk_workspace(res) - { set_params(res, params); } void set_params(raft::resources const& res, const search_params& params) { - constexpr unsigned muti_cta_itopk_size = 32; - this->itopk_size = muti_cta_itopk_size; - search_width = 1; + constexpr unsigned multi_cta_itopk_size = 32; + this->itopk_size = multi_cta_itopk_size; + search_width = 1; num_cta_per_query = - max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size)); + max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size)); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); @@ -128,7 +128,8 @@ struct search : public search_plan_impl +template RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( - INDEX_T* const next_parent_indices, // [search_width] - const uint32_t search_width, - INDEX_T* const itopk_indices, // [num_itopk] - const size_t num_itopk, - uint32_t* const 
terminate_flag) + INDEX_T* const next_parent_indices, // [num_parents] + const uint32_t num_parents, + INDEX_T* const itopk_indices, // [num_itopk] + DISTANCE_T* const itopk_distances, // [num_itopk] + const uint32_t num_itopk, // (*) num_itopk <= 32 + INDEX_T* const hash_ptr, + const uint32_t hash_bitlen) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } const unsigned lane_id = threadIdx.x % 32; - for (uint32_t i = lane_id; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - uint32_t max_itopk = num_itopk; - if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } - uint32_t num_new_parents = 0; - for (uint32_t j = lane_id; j < max_itopk; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < num_itopk) { - index = itopk_indices[j]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } + + // Initialize + if (lane_id < num_parents) { next_parent_indices[lane_id] = ~static_cast(0); } + INDEX_T index = ~static_cast(0); + if (lane_id < num_itopk) { index = itopk_indices[lane_id]; } + + int is_candidate = 0; + if ((index & index_msb_1_mask) == 0) { + if (hashmap::search(hash_ptr, hash_bitlen, index)) { + // Deactivate nodes that have already been used by other CTAs. 
+ index = ~static_cast(0); + itopk_indices[lane_id] = index; + itopk_distances[lane_id] = utils::get_max_value(); + } else { + is_candidate = 1; } - const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = j; - itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node + } + + uint32_t num_next_parents = 0; + while (num_next_parents < num_parents) { + const uint32_t ballot_mask = __ballot_sync(0xffffffff, is_candidate); + int num_candidates = __popc(ballot_mask); + if (num_candidates == 0) { return; } + int is_found = 0; + if (is_candidate) { + const auto candidate_id = __popc(ballot_mask & ((1 << lane_id) - 1)); + if (candidate_id == 0) { + if (hashmap::insert(hash_ptr, hash_bitlen, index)) { + // Use this candidate as next parent + next_parent_indices[num_next_parents] = lane_id; + index |= index_msb_1_mask; // set most significant bit as used node + is_found = 1; + } else { + // Deactivate the node since it has been used by other CTA. 
+ index = ~static_cast(0); + itopk_distances[lane_id] = utils::get_max_value(); + } + itopk_indices[lane_id] = index; + is_candidate = 0; } } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } + if (__ballot_sync(0xffffffff, is_found)) { num_next_parents += 1; } } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } } template @@ -121,12 +139,12 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( } /* Warp Sort */ bitonic::warp_sort(key, val); - /* Store itopk sorted results */ + /* Store sorted results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - distances[j] = key[i]; - indices[j] = val[i]; + if (j < num_elements) { + indices[j] = val[i]; + if (j < num_itopk) { distances[j] = key[i]; } } } } @@ -148,9 +166,10 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( const uint64_t rand_xor_mask, const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, + const uint32_t visited_hash_bitlen, typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const uint32_t hash_bitlen, + traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] + const uint32_t traversed_hash_bitlen, const uint32_t itopk_size, const uint32_t search_width, const uint32_t min_iteration, @@ -185,11 +204,11 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( extern __shared__ uint8_t smem[]; // Layout of result_buffer - // +----------------+------------------------------+---------+ - // | internal_top_k | neighbors of parent nodes | padding | + // +----------------+-------------------------------+---------+ + // | internal_top_k | neighbors of parent nodes | padding | // | | | upto 32 | - // +----------------+------------------------------+---------+ - // |<--- result_buffer_size --->| + // +----------------+-------------------------------+---------+ + // |<--- 
result_buffer_size --->| const auto result_buffer_size = itopk_size + (search_width * graph_degree); const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); assert(result_buffer_size_32 <= MAX_ELEMENTS); @@ -201,10 +220,10 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( reinterpret_cast(smem + dataset_desc->smem_ws_size_in_bytes()); auto* __restrict__ result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto* __restrict__ parent_indices_buffer = + auto* __restrict__ local_visited_hashmap_ptr = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* __restrict__ terminate_flag = - reinterpret_cast(parent_indices_buffer + search_width); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); #if 0 /* debug */ @@ -214,9 +233,10 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( } #endif - if (threadIdx.x == 0) { terminate_flag[0] = 0; } - INDEX_T* const local_visited_hashmap_ptr = - visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); + + INDEX_T* const local_traversed_hashmap_ptr = + traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); __syncthreads(); _CLK_REC(clk_init); @@ -235,35 +255,51 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( local_seed_ptr, num_seeds, local_visited_hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, block_id, num_blocks); __syncthreads(); _CLK_REC(clk_compute_1st_distance); - uint32_t iter = 0; + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + uint32_t iter = 0; while (1) { - // topk with bitonic sort + // Topk with bitonic sort (1st warp only) _CLK_START(); topk_by_bitonic_sort(result_distances_buffer, result_indices_buffer, itopk_size + (search_width * graph_degree), 
itopk_size); _CLK_REC(clk_topk); + __syncthreads(); - if (iter + 1 == max_iteration) { - __syncthreads(); - break; + if (iter + 1 == max_iteration) { break; } + + // Remove entries kicked out of the itopk list from the traversed hash table. + for (unsigned i = threadIdx.x; i < search_width * graph_degree; i += blockDim.x) { + INDEX_T index = result_indices_buffer[itopk_size + i]; + if ((index & index_msb_1_mask) == 0 || (index == ~static_cast(0))) { continue; } + index &= ~index_msb_1_mask; + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); } - // pick up next parents + // Pick up next parents (1st warp only) _CLK_START(); - pickup_next_parents( - parent_indices_buffer, search_width, result_indices_buffer, itopk_size, terminate_flag); + pickup_next_parents(parent_indices_buffer, + search_width, + result_indices_buffer, + result_distances_buffer, + itopk_size, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); _CLK_REC(clk_pickup_parents); - __syncthreads(); - if (*terminate_flag && iter >= min_iteration) { break; } + + if ((parent_indices_buffer[0] == ~static_cast(0)) && (iter >= min_iteration)) { + break; + } // compute the norms between child nodes and query node _CLK_START(); @@ -273,7 +309,9 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( knn_graph, graph_degree, local_visited_hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, parent_indices_buffer, result_indices_buffer, search_width); @@ -303,36 +341,19 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( iter++; } - // Post process for filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < itopk_size + search_width * graph_degree; i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != 
(invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[i] = utils::get_max_value(); - result_indices_buffer[i] = invalid_index; - } - } - - __syncthreads(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); - __syncthreads(); - } - for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) { - uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit + uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + INDEX_T index = result_indices_buffer[i]; + DISTANCE_T distance = result_distances_buffer[i]; + if (index & index_msb_1_mask) { + index &= ~index_msb_1_mask; // clear most significant bit + } else { + // This entry has not been used as parent, so deactivate this. 
+ index = ~static_cast(0); + distance = utils::get_max_value(); + } + result_indices_ptr[j] = index; + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = distance; } } if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { @@ -427,8 +448,9 @@ void select_and_run(const dataset_descriptor_host& dat uint32_t block_size, // uint32_t result_buffer_size, uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, uint32_t num_cta_per_query, uint32_t num_seeds, SampleFilterT sample_filter, @@ -441,9 +463,13 @@ void select_and_run(const dataset_descriptor_host& dat RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Initialize hash table - const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch( - hashmap_ptr, hash_size, utils::get_max_value(), hash_size, num_queries, stream); + const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); + set_value_batch(traversed_hashmap_ptr, + traversed_hash_size, + ~static_cast(0), + traversed_hash_size, + num_queries, + stream); dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); @@ -463,8 +489,9 @@ void select_and_run(const dataset_descriptor_host& dat ps.rand_xor_mask, dev_seed_ptr, num_seeds, - hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen, ps.itopk_size, ps.search_width, ps.min_iterations, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh index 1a1dcd579..e5dc29f27 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh @@ -36,8 +36,9 @@ void select_and_run(const dataset_descriptor_host& dat uint32_t block_size, // uint32_t result_buffer_size, uint32_t 
smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, uint32_t num_cta_per_query, uint32_t num_seeds, SampleFilterT sample_filter, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index c6fe21642..be92be999 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -635,9 +635,10 @@ struct search : search_plan_impl { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk), + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), result_indices(res), result_distances(res), parent_node_list(res), diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 99254aa50..2bbf3d56a 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -108,11 +108,13 @@ struct lightweight_uvector { }; struct search_plan_impl_base : public search_params { + int64_t dataset_size; int64_t dim; int64_t graph_degree; uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk) + search_plan_impl_base(search_params params, int64_t dim, int64_t dataset_size, + int64_t graph_degree, uint32_t topk) + : search_params(params), dim(dim), dataset_size(dataset_size), graph_degree(graph_degree), topk(topk) { if (algo == search_algo::AUTO) { const size_t num_sm = raft::getMultiProcessorCount(); @@ -141,7 +143,6 @@ struct search_plan_impl : public search_plan_impl_base { size_t small_hash_bitlen; size_t small_hash_reset_interval; size_t hashmap_size; 
- uint32_t dataset_size; uint32_t result_buffer_size; uint32_t smem_size; @@ -157,9 +158,10 @@ struct search_plan_impl : public search_plan_impl_base { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : search_plan_impl_base(params, dim, graph_degree, topk), + : search_plan_impl_base(params, dim, dataset_size, graph_degree, topk), hashmap(res), num_executed_iterations(res), dev_seed(res), @@ -193,10 +195,16 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t _max_iterations = max_iterations; if (max_iterations == 0) { if (algo == search_algo::MULTI_CTA) { - _max_iterations = 1 + std::min(32 * 1.1, 32 + 10.0); // TODO(anaruse) + constexpr uint32_t mc_itopk_size = 32; + constexpr uint32_t mc_search_width = 1; + _max_iterations = mc_itopk_size / mc_search_width; } else { - _max_iterations = - 1 + std::min((itopk_size / search_width) * 1.1, (itopk_size / search_width) + 10.0); + _max_iterations = itopk_size / search_width; + } + int64_t num_reachable_nodes = 1; + while (num_reachable_nodes < dataset_size) { + num_reachable_nodes *= graph_degree / 2; + _max_iterations += 1; } } if (max_iterations < min_iterations) { _max_iterations = min_iterations; } @@ -219,88 +227,106 @@ struct search_plan_impl : public search_plan_impl_base { // defines hash_bitlen, small_hash_bitlen, small_hash_reset interval, hash_size inline void calc_hashmap_params(raft::resources const& res) { - // for multiple CTA search - uint32_t mc_num_cta_per_query = 0; - uint32_t mc_search_width = 0; - uint32_t mc_itopk_size = 0; - if (algo == search_algo::MULTI_CTA) { - mc_itopk_size = 32; - mc_search_width = 1; - mc_num_cta_per_query = max(search_width, raft::ceildiv(itopk_size, (size_t)32)); - RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); - RAFT_LOG_DEBUG("# mc_search_width: %u", mc_search_width); - RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); - } - // Determine 
hash size (bit length) hashmap_size = 0; hash_bitlen = 0; small_hash_bitlen = 0; small_hash_reset_interval = 1024 * 1024; float max_fill_rate = hashmap_max_fill_rate; - while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { - // - // The small-hash reduces hash table size by initializing the hash table - // for each iteration and re-registering only the nodes that should not be - // re-visited in that iteration. Therefore, the size of small-hash should - // be determined based on the internal topk size and the number of nodes - // visited per iteration. - // - const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); - unsigned min_bitlen = 8; // 256 - unsigned max_bitlen = 13; // 8K - if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } - hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { - hash_bitlen += 1; - } - if (hash_bitlen > max_bitlen) { - // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. - if (hashmap_mode == hash_mode::AUTO) { - hash_bitlen = 0; - break; - } else { - RAFT_FAIL( - "small-hash cannot be used because the required hash size exceeds the limit (%u)", - hashmap::get_size(max_bitlen)); - } - } - small_hash_bitlen = hash_bitlen; + if (algo == search_algo::MULTI_CTA) { + const uint32_t mc_itopk_size = 32; + const uint32_t mc_num_cta_per_query = + max(search_width, raft::ceildiv(itopk_size, (size_t)mc_itopk_size)); + RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); + RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); // - // Sincc the hash table size is limited to a power of 2, the requirement, - // the maximum fill rate, may be satisfied even if the frequency of hash - // table reset is reduced to once every 2 or more iterations without - // changing the hash table size. In that case, reduce the reset frequency. 
+ // [visited_hash_table] + // In the multi CTA algo, which node has been visited is managed in a hash + // table that each CTA has in the shared memory. This hash table is not + // shared among CTAs. // - small_hash_reset_interval = 1; - while (1) { - const auto max_visited_nodes = - itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); - if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } - small_hash_reset_interval += 1; + const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations); + small_hash_bitlen = 11; // 2K + while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { + small_hash_bitlen += 1; } - break; - } - if (hash_bitlen == 0) { + RAFT_EXPECTS(small_hash_bitlen <= 14, "small_hash_bitlen cannot be largen than 14 (16K)"); // - // The size of hash table is determined based on the maximum number of - // nodes that may be visited before the search is completed and the - // maximum fill rate of the hash table. + // [traversed_hash_table] + // Whether a node has ever been used as the starting point for a traversal + // in each iteration is managed in a separate hash table, which is shared + // among the CTAs. 
// - uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); - if (algo == search_algo::MULTI_CTA) { - max_visited_nodes = mc_itopk_size + (mc_search_width * graph_degree * max_iterations); - max_visited_nodes *= mc_num_cta_per_query; - } + const auto max_traversed_nodes = mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); unsigned min_bitlen = 11; // 2K if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + while (max_traversed_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { hash_bitlen += 1; } - RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be largen than 20 (1M)"); + RAFT_EXPECTS(hash_bitlen <= 25, "hash_bitlen cannot be largen than 25 (32M)"); + } else { + while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { + // + // The small-hash reduces hash table size by initializing the hash table + // for each iteration and re-registering only the nodes that should not be + // re-visited in that iteration. Therefore, the size of small-hash should + // be determined based on the internal topk size and the number of nodes + // visited per iteration. + // + const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); + unsigned min_bitlen = 8; // 256 + unsigned max_bitlen = 13; // 8K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + hash_bitlen += 1; + } + if (hash_bitlen > max_bitlen) { + // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. 
+ if (hashmap_mode == hash_mode::AUTO) { + hash_bitlen = 0; + break; + } else { + RAFT_FAIL( + "small-hash cannot be used because the required hash size exceeds the limit (%u)", + hashmap::get_size(max_bitlen)); + } + } + small_hash_bitlen = hash_bitlen; + // + // Sincc the hash table size is limited to a power of 2, the requirement, + // the maximum fill rate, may be satisfied even if the frequency of hash + // table reset is reduced to once every 2 or more iterations without + // changing the hash table size. In that case, reduce the reset frequency. + // + small_hash_reset_interval = 1; + while (1) { + const auto max_visited_nodes = + itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); + if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } + small_hash_reset_interval += 1; + } + break; + } + if (hash_bitlen == 0) { + // + // The size of hash table is determined based on the maximum number of + // nodes that may be visited before the search is completed and the + // maximum fill rate of the hash table. + // + uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); + unsigned min_bitlen = 11; // 2K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + hash_bitlen += 1; + } + RAFT_EXPECTS(hash_bitlen <= 20, + "hash_bitlen cannot be largen than 20 (1M). 
You can decrease itopk_size, " + "search_width or max_iterations to reduce the required hashmap size."); + } } - RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size); RAFT_LOG_DEBUG("# parent size = %lu", search_width); RAFT_LOG_DEBUG("# min_iterations = %lu", min_iterations); diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index fa71dbaf9..0911d440c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -94,9 +94,10 @@ struct search : search_plan_impl { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk) + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk) { set_params(res); } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 678ed0cb4..0eedb8d09 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -622,7 +622,9 @@ __device__ void search_core( local_seed_ptr, num_seeds, local_visited_hashmap_ptr, - hash_bitlen); + hash_bitlen, + (INDEX_T*) nullptr, + 0); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -749,6 +751,8 @@ __device__ void search_core( graph_degree, local_visited_hashmap_ptr, hash_bitlen, + (INDEX_T*) nullptr, + 0, parent_list_buffer, result_indices_buffer, search_width); From 37e26c1bd2bca4b41ce0510e842a07ea3ead0816 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 5 Dec 2024 07:51:06 -0800 Subject: [PATCH 02/12] fix style --- cpp/src/neighbors/detail/cagra/factory.cuh | 6 ++-- cpp/src/neighbors/detail/cagra/hashmap.hpp | 32 +++++++++---------- .../neighbors/detail/cagra/search_plan.cuh | 19 +++++++---- 
.../cagra/search_single_cta_kernel-inl.cuh | 4 +-- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 064f880ad..d2ae5c55b 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -57,15 +57,15 @@ class factory { if (plan.algo == search_algo::SINGLE_CTA) { return std::make_unique< single_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else if (plan.algo == search_algo::MULTI_CTA) { return std::make_unique< multi_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else { return std::make_unique< multi_kernel_search::search>( - res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } } }; diff --git a/cpp/src/neighbors/detail/cagra/hashmap.hpp b/cpp/src/neighbors/detail/cagra/hashmap.hpp index 6dbdd5a8a..da736ef5e 100644 --- a/cpp/src/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/src/neighbors/detail/cagra/hashmap.hpp @@ -62,7 +62,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, const uint32_t stride = (key >> bitlen) * 2 + 1; #endif constexpr IdxT hashval_empty = ~static_cast(0); - const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { const IdxT old = atomicCAS(&table[index], hashval_empty, key); if (old == hashval_empty) { @@ -86,19 +86,19 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, template RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, 
const uint32_t bitlen, const IdxT key) { - const uint32_t size = get_size(bitlen); - const uint32_t bit_mask = size - 1; + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; #ifdef HASHMAP_LINEAR_PROBING - // Linear probing - IdxT index = (key ^ (key >> bitlen)) & bit_mask; - constexpr uint32_t stride = 1; + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; #else - // Double hashing - IdxT index = key & bit_mask; - const uint32_t stride = (key >> bitlen) * 2 + 1; + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; #endif constexpr IdxT hashval_empty = ~static_cast(0); - const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { const IdxT val = table[index]; if (val == key) { @@ -107,9 +107,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, return 0; } else if (SUPPORT_REMOVE) { // Check if this key has been removed. 
- if (val == removed_key) { - return 0; - } + if (val == removed_key) { return 0; } } index = (index + stride) & bit_mask; } @@ -119,19 +117,19 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, template RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key) { - const uint32_t size = get_size(bitlen); + const uint32_t size = get_size(bitlen); const uint32_t bit_mask = size - 1; #ifdef HASHMAP_LINEAR_PROBING // Linear probing - IdxT index = (key ^ (key >> bitlen)) & bit_mask; + IdxT index = (key ^ (key >> bitlen)) & bit_mask; constexpr uint32_t stride = 1; #else // Double hashing - IdxT index = key & bit_mask; + IdxT index = key & bit_mask; const uint32_t stride = (key >> bitlen) * 2 + 1; #endif constexpr IdxT hashval_empty = ~static_cast(0); - const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { // To remove a key, set the MSB to 1. 
const uint32_t old = atomicCAS(&table[index], key, removed_key); diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 2bbf3d56a..5b6b58a13 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -112,9 +112,13 @@ struct search_plan_impl_base : public search_params { int64_t dim; int64_t graph_degree; uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t dataset_size, - int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), dataset_size(dataset_size), graph_degree(graph_degree), topk(topk) + search_plan_impl_base( + search_params params, int64_t dim, int64_t dataset_size, int64_t graph_degree, uint32_t topk) + : search_params(params), + dim(dim), + dataset_size(dataset_size), + graph_degree(graph_degree), + topk(topk) { if (algo == search_algo::AUTO) { const size_t num_sm = raft::getMultiProcessorCount(); @@ -195,9 +199,9 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t _max_iterations = max_iterations; if (max_iterations == 0) { if (algo == search_algo::MULTI_CTA) { - constexpr uint32_t mc_itopk_size = 32; + constexpr uint32_t mc_itopk_size = 32; constexpr uint32_t mc_search_width = 1; - _max_iterations = mc_itopk_size / mc_search_width; + _max_iterations = mc_itopk_size / mc_search_width; } else { _max_iterations = itopk_size / search_width; } @@ -246,7 +250,7 @@ struct search_plan_impl : public search_plan_impl_base { // shared among CTAs. // const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations); - small_hash_bitlen = 11; // 2K + small_hash_bitlen = 11; // 2K while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { small_hash_bitlen += 1; } @@ -257,7 +261,8 @@ struct search_plan_impl : public search_plan_impl_base { // in each iteration is managed in a separate hash table, which is shared // among the CTAs. 
// - const auto max_traversed_nodes = mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); + const auto max_traversed_nodes = + mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); unsigned min_bitlen = 11; // 2K if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } hash_bitlen = min_bitlen; diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 0eedb8d09..94c97ed16 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -623,7 +623,7 @@ __device__ void search_core( num_seeds, local_visited_hashmap_ptr, hash_bitlen, - (INDEX_T*) nullptr, + (INDEX_T*)nullptr, 0); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -751,7 +751,7 @@ __device__ void search_core( graph_degree, local_visited_hashmap_ptr, hash_bitlen, - (INDEX_T*) nullptr, + (INDEX_T*)nullptr, 0, parent_list_buffer, result_indices_buffer, From ab1130bc48f1df0ff085eb33f2486af4b281e87f Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 9 Dec 2024 14:44:50 +0100 Subject: [PATCH 03/12] Check if CAGRA search returns enough valid indices during add_nodes Handle the case when the search result contains invalid indices when building the updated graph in add_nodes. For debugging purposes, fail if any invalid indices found; in future, we can replace RAFT_FAIL with RAFT_LOG_WARN to make the add_nodes routine more robust. 
--- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 33 ++++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index b03b8214b..66da91c57 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -127,6 +127,7 @@ void add_node_core( raft::resource::sync_stream(handle); // Step 2: rank-based reordering + bool search_returned_not_enough_values = false; #pragma omp parallel { std::vector> detourable_node_count_list(base_degree); @@ -135,10 +136,29 @@ void add_node_core( // Count detourable edges for (std::uint32_t i = 0; i < base_degree; i++) { std::uint32_t detourable_node_count = 0; - const auto a_id = host_neighbor_indices(vec_i, i); + // TODO: the invalid indices may be produced by neighbors::cagra::search above. + // This may happen if the search function hasn't returned enough values. + // - A valid reason could be: the index size is smaller than the base degree. + // - A bad reason could be: search iterations is set to a too low value or some + // other problem with the search config. + // - This could also be a bug in the search function + // In the following, we check the indices and assign low priorities to invalid links, + // so that they are not likely to appear in the final graph. 
+ const auto a_id = host_neighbor_indices(vec_i, i); + if (a_id >= idx.size()) { + detourable_node_count_list[i] = std::make_pair(a_id, base_degree); +#pragma omp atomic write + search_returned_not_enough_values = true; + continue; + } for (std::uint32_t j = 0; j < i; j++) { const auto b0_id = host_neighbor_indices(vec_i, j); - assert(b0_id < idx.size()); + if (b0_id >= idx.size()) { +#pragma omp atomic write + search_returned_not_enough_values = true; + detourable_node_count++; + continue; + } for (std::uint32_t k = 0; k < degree; k++) { const auto b1_id = updated_graph(b0_id, k); if (a_id == b1_id) { @@ -160,6 +180,11 @@ void add_node_core( } } } + if (search_returned_not_enough_values) { + RAFT_FAIL( + "CAGRA search returned not enough valid indices to add new nodes to the graph. " + "The resulting graph may contain invalid edges at the new nodes."); + } // Step 3: Add reverse edges const std::uint32_t rev_edge_search_range = degree / 2; @@ -248,7 +273,9 @@ void add_graph_nodes( raft::host_matrix_view updated_graph_view, const cagra::extend_params& params) { - assert(input_updated_dataset_view.extent(0) >= index.size()); + if (input_updated_dataset_view.extent(0) < index.size()) { + RAFT_FAIL("Updated dataset must be not smaller than the previous index state."); + } const std::size_t initial_dataset_size = index.size(); const std::size_t new_dataset_size = input_updated_dataset_view.extent(0); From bedd22492efaec48176f984c6e978b45037ec854 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Fri, 20 Dec 2024 21:20:32 +0900 Subject: [PATCH 04/12] Resolving various issues with the new multi-CTA algorithm --- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 68 +++--- .../neighbors/detail/cagra/device_common.hpp | 2 +- cpp/src/neighbors/detail/cagra/hashmap.hpp | 2 +- .../detail/cagra/search_multi_cta.cuh | 2 +- .../cagra/search_multi_cta_kernel-inl.cuh | 228 +++++++++++++----- .../neighbors/detail/cagra/search_plan.cuh | 6 +- 6 files changed, 211 insertions(+), 97 
deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 66da91c57..d03fe7b95 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -126,8 +126,24 @@ void add_node_core( raft::resource::get_cuda_stream(handle)); raft::resource::sync_stream(handle); + // Check search results + for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) { + std::uint32_t invalid_edges = 0; + for (std::uint32_t i = 0; i < base_degree; i++) { + if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; } + } + if (invalid_edges > 0) { + RAFT_LOG_WARN( + "Invalid edges found in search results " + "(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)", + (uint64_t)vec_i, + (uint64_t)invalid_edges, + (uint64_t)degree, + (uint64_t)base_degree); + } + } + // Step 2: rank-based reordering - bool search_returned_not_enough_values = false; #pragma omp parallel { std::vector> detourable_node_count_list(base_degree); @@ -136,29 +152,14 @@ void add_node_core( // Count detourable edges for (std::uint32_t i = 0; i < base_degree; i++) { std::uint32_t detourable_node_count = 0; - // TODO: the invalid indices may be produced by neighbors::cagra::search above. - // This may happen if the search function hasn't returned enough values. - // - A valid reason could be: the index size is smaller than the base degree. - // - A bad reason could be: search iterations is set to a too low value or some - // other problem with the search config. - // - This could also be a bug in the search function - // In the following, we check the indices and assign low priorities to invalid links, - // so that they are not likely to appear in the final graph. 
- const auto a_id = host_neighbor_indices(vec_i, i); + const auto a_id = host_neighbor_indices(vec_i, i); if (a_id >= idx.size()) { - detourable_node_count_list[i] = std::make_pair(a_id, base_degree); -#pragma omp atomic write - search_returned_not_enough_values = true; + detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1); continue; } for (std::uint32_t j = 0; j < i; j++) { const auto b0_id = host_neighbor_indices(vec_i, j); - if (b0_id >= idx.size()) { -#pragma omp atomic write - search_returned_not_enough_values = true; - detourable_node_count++; - continue; - } + if (b0_id >= idx.size()) { continue; } for (std::uint32_t k = 0; k < degree; k++) { const auto b1_id = updated_graph(b0_id, k); if (a_id == b1_id) { @@ -169,6 +170,7 @@ void add_node_core( } detourable_node_count_list[i] = std::make_pair(a_id, detourable_node_count); } + std::sort(detourable_node_count_list.begin(), detourable_node_count_list.end(), [&](const std::pair a, const std::pair b) { @@ -180,11 +182,6 @@ void add_node_core( } } } - if (search_returned_not_enough_values) { - RAFT_FAIL( - "CAGRA search returned not enough valid indices to add new nodes to the graph. 
" - "The resulting graph may contain invalid edges at the new nodes."); - } // Step 3: Add reverse edges const std::uint32_t rev_edge_search_range = degree / 2; @@ -195,13 +192,18 @@ void add_node_core( const auto target_new_node_id = old_size + batch.offset() + vec_i; for (std::size_t i = 0; i < num_rev_edges; i++) { const auto target_node_id = updated_graph(old_size + batch.offset() + vec_i, i); - + if (target_node_id >= new_size) { + RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", target_node_id); + } IdxT replace_id = new_size; IdxT replace_id_j = 0; std::size_t replace_num_incoming_edges = 0; for (std::int32_t j = degree - 1; j >= static_cast(rev_edge_search_range); j--) { - const auto neighbor_id = updated_graph(target_node_id, j); + const auto neighbor_id = updated_graph(target_node_id, j); + if (neighbor_id >= new_size) { + RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", neighbor_id); + } const std::size_t num_incoming_edges = host_num_incoming_edges(neighbor_id); if (num_incoming_edges > replace_num_incoming_edges) { // Check duplication @@ -220,10 +222,6 @@ void add_node_core( replace_id_j = j; } } - if (replace_id >= new_size) { - std::fprintf(stderr, "Invalid rev edge index (%u)\n", replace_id); - return; - } updated_graph(target_node_id, replace_id_j) = target_new_node_id; rev_edges[i] = replace_id; } @@ -235,13 +233,15 @@ void add_node_core( const auto rank_based_list_ptr = updated_graph.data_handle() + (old_size + batch.offset() + vec_i) * degree; const auto rev_edges_return_list_ptr = rev_edges.data(); - while (num_add < degree) { + while ((num_add < degree) && + ((rank_base_i < degree) || (rev_edges_return_i < num_rev_edges))) { const auto node_list_ptr = interleave_switch == 0 ? rank_based_list_ptr : rev_edges_return_list_ptr; auto& node_list_index = interleave_switch == 0 ? rank_base_i : rev_edges_return_i; const auto max_node_list_index = interleave_switch == 0 ? 
degree : num_rev_edges; for (; node_list_index < max_node_list_index; node_list_index++) { const auto candidate = node_list_ptr[node_list_index]; + if (candidate >= new_size) { continue; } // Check duplication bool dup = false; for (std::uint32_t j = 0; j < num_add; j++) { @@ -258,6 +258,12 @@ void add_node_core( } interleave_switch = 1 - interleave_switch; } + if (num_add < degree) { + RAFT_FAIL("Number of edges is not enough (target_new_node_id:%u, num_add:%u, degree:%u)", + target_new_node_id, + num_add, + degree); + } for (std::uint32_t i = 0; i < degree; i++) { updated_graph(target_new_node_id, i) = temp[i]; } diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 9dcc5123b..c20d58994 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -185,7 +185,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const uint32_t search_width) { constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; - constexpr IndexT invalid_index = raft::upper_bound(); + constexpr IndexT invalid_index = ~static_cast(0); // Read child indices of parents from knn graph and check if the distance // computaiton is necessary. 
diff --git a/cpp/src/neighbors/detail/cagra/hashmap.hpp b/cpp/src/neighbors/detail/cagra/hashmap.hpp index da736ef5e..652e1db22 100644 --- a/cpp/src/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/src/neighbors/detail/cagra/hashmap.hpp @@ -40,7 +40,7 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table, { if (threadIdx.x < FIRST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { - table[i] = utils::get_max_value(); + table[i] = ~static_cast(0); } } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 8d425ca67..8a97173fa 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -127,7 +127,7 @@ struct search : public search_plan_impl::value; - const unsigned warp_id = threadIdx.x / 32; + constexpr INDEX_T invalid_index = ~static_cast(0); + + const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } const unsigned lane_id = threadIdx.x % 32; // Initialize if (lane_id < num_parents) { next_parent_indices[lane_id] = ~static_cast(0); } - INDEX_T index = ~static_cast(0); + INDEX_T index = invalid_index; if (lane_id < num_itopk) { index = itopk_indices[lane_id]; } int is_candidate = 0; - if ((index & index_msb_1_mask) == 0) { + if ((index != invalid_index) && ((index & index_msb_1_mask) == 0)) { if (hashmap::search(hash_ptr, hash_bitlen, index)) { // Deactivate nodes that have already been used by other CTAs. - index = ~static_cast(0); - itopk_indices[lane_id] = index; + itopk_indices[lane_id] = invalid_index; itopk_distances[lane_id] = utils::get_max_value(); + index = invalid_index; } else { is_candidate = 1; } @@ -102,7 +104,7 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( is_found = 1; } else { // Deactivate the node since it has been used by other CTA. 
- index = ~static_cast(0); + index = invalid_index; itopk_distances[lane_id] = utils::get_max_value(); } itopk_indices[lane_id] = index; @@ -134,7 +136,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( val[i] = indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = ~static_cast(0); } } /* Warp Sort */ @@ -143,12 +145,45 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; if (j < num_elements) { - indices[j] = val[i]; - if (j < num_itopk) { distances[j] = key[i]; } + distances[j] = key[i]; + indices[j] = val[i]; } } } +template +RAFT_DEVICE_INLINE_FUNCTION void move_valid_entries_to_head( + INDEX_T* indices, // [num_elements] + DISTANCE_T* distances, // [num_elements] + const uint32_t num_elements // (*) num_elements must be multiple of 32 +) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr INDEX_T invalid_index = ~static_cast(0); + const uint32_t warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + const uint32_t lane_id = threadIdx.x % 32; + uint32_t offset = 0; + for (uint32_t i = lane_id; i < num_elements; i += 32) { + auto index = indices[i]; + auto distance = distances[i]; + bool is_valid = (index != invalid_index); + const auto mask = __ballot_sync(0xffffffff, is_valid); + const auto j = offset + __popc(mask & ((1 << lane_id) - 1)); + if ((index != invalid_index) && (j < i)) { + indices[j] = index; + distances[j] = distance; + } + offset += __popc(mask); + __syncwarp(); + } + for (uint32_t i = offset + lane_id; i < num_elements; i += 32) { + indices[i] = invalid_index; + distances[i] = utils::get_max_value(); + } + __syncwarp(); +} + // // multiple CTAs per single query // @@ -225,18 +260,17 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( auto* __restrict__ parent_indices_buffer = reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); -#if 0 - /* 
debug */ - for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - result_indices_buffer[i] = utils::get_max_value(); - result_distances_buffer[i] = utils::get_max_value(); - } -#endif - - hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); - INDEX_T* const local_traversed_hashmap_ptr = traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); + + constexpr INDEX_T invalid_index = ~static_cast(0); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); __syncthreads(); _CLK_REC(clk_init); @@ -263,46 +297,77 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( __syncthreads(); _CLK_REC(clk_compute_1st_distance); - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - uint32_t iter = 0; + uint32_t iter = 0; while (1) { - // Topk with bitonic sort (1st warp only) _CLK_START(); + // Topk with bitonic sort (1st warp only) topk_by_bitonic_sort(result_distances_buffer, result_indices_buffer, itopk_size + (search_width * graph_degree), itopk_size); - _CLK_REC(clk_topk); __syncthreads(); + _CLK_REC(clk_topk); - if (iter + 1 == max_iteration) { break; } - - // Remove entries kicked out of the itopk list from the traversed hash table. 
- for (unsigned i = threadIdx.x; i < search_width * graph_degree; i += blockDim.x) { - INDEX_T index = result_indices_buffer[itopk_size + i]; - if ((index & index_msb_1_mask) == 0 || (index == ~static_cast(0))) { continue; } - index &= ~index_msb_1_mask; - hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); - } + if (iter + 1 >= max_iteration) { break; } - // Pick up next parents (1st warp only) _CLK_START(); - pickup_next_parents(parent_indices_buffer, - search_width, - result_indices_buffer, - result_distances_buffer, - itopk_size, - local_traversed_hashmap_ptr, - traversed_hash_bitlen); - _CLK_REC(clk_pickup_parents); + if (threadIdx.x < 32) { + // [1st warp] Pick up next parents + pickup_next_parents(parent_indices_buffer, + search_width, + result_indices_buffer, + result_distances_buffer, + itopk_size, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); +#if 0 + if (parent_indices_buffer[0] == invalid_index) { + // Try again if no parent is found + move_valid_entries_to_head(result_indices_buffer, + result_distances_buffer, + result_buffer_size_32); + pickup_next_parents(parent_indices_buffer, + search_width, + result_indices_buffer, + result_distances_buffer, + itopk_size, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); + } +#endif + } else { + // [Other warps] Reset visited hashmap + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); + } __syncthreads(); + _CLK_REC(clk_pickup_parents); - if ((parent_indices_buffer[0] == ~static_cast(0)) && (iter >= min_iteration)) { - break; + if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } + + if (threadIdx.x < 32) { + // [1st warp] Restore visited hashmap by putting itopk indices in it. 
+ for (unsigned i = threadIdx.x; i < itopk_size; i += 32) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); + } + } else { + // [Other warps] Remove entries kicked out of the itopk list from the + // traversed hash table. + for (unsigned i = threadIdx.x - 32; i < search_width * graph_degree; i += blockDim.x - 32) { + INDEX_T index = result_indices_buffer[itopk_size + i]; + if (index == invalid_index) { continue; } + if (index & index_msb_1_mask) { + hashmap::remove( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + } + } } + __syncthreads(); - // compute the norms between child nodes and query node _CLK_START(); + // compute the norms between child nodes and query node device::compute_distance_to_child_nodes(result_indices_buffer + itopk_size, result_distances_buffer + itopk_size, *dataset_desc, @@ -315,15 +380,12 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( parent_indices_buffer, result_indices_buffer, search_width); - _CLK_REC(clk_compute_distance); __syncthreads(); + _CLK_REC(clk_compute_distance); // Filtering if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { if (parent_indices_buffer[p] != invalid_index) { const auto parent_id = @@ -341,19 +403,65 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( iter++; } - for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) { - uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - INDEX_T index = result_indices_buffer[i]; - DISTANCE_T distance = result_distances_buffer[i]; - if (index & index_msb_1_mask) { - index &= ~index_msb_1_mask; // clear most significant bit - } else { - // This entry has not been used as 
parent, so deactivate this. - index = ~static_cast(0); - distance = utils::get_max_value(); + // Filtering + if constexpr (!std::is_same::value) { + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + if (!sample_filter(query_id, index)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + } + + // Output search results (1st warp only). + if (threadIdx.x < 32) { + uint32_t offset = 0; + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += 32) { + INDEX_T index = result_indices_buffer[i]; + bool is_valid = false; + if (index != invalid_index) { + if (index & index_msb_1_mask) { + is_valid = true; + index &= ~index_msb_1_mask; + } else if (hashmap::insert( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + // If a node that is not used as a parent can be inserted into + // the traversed hash table, it is considered a valid result. + is_valid = true; + } + } + const auto mask = __ballot_sync(0xffffffff, is_valid); + if (is_valid) { + const auto j = offset + __popc(mask & ((1 << threadIdx.x) - 1)); + if (j < itopk_size) { + uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = index & ~index_msb_1_mask; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = result_distances_buffer[i]; + } + } else if ((index & index_msb_1_mask) == 0) { + // If a node that was successfully inserted in the traversed + // hash table is not output as a result, the hash table is + // restored using hash remove. + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); + } + } + offset += __popc(mask); + if (offset >= itopk_size) break; + } + // If the number of outputs is insufficient, fill in with invalid results. 
+ for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { + uint32_t k = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = invalid_index; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = utils::get_max_value(); + } } - result_indices_ptr[j] = index; - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = distance; } } if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 5b6b58a13..7300c89c6 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -247,10 +247,10 @@ struct search_plan_impl : public search_plan_impl_base { // [visited_hash_table] // In the multi CTA algo, which node has been visited is managed in a hash // table that each CTA has in the shared memory. This hash table is not - // shared among CTAs. + // shared among CTAs. This hash table is reset and restored in each iteration. 
// - const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations); - small_hash_bitlen = 11; // 2K + const uint32_t max_visited_nodes = mc_itopk_size + graph_degree; + small_hash_bitlen = 8; // 256 while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { small_hash_bitlen += 1; } From ea8c273f3142bf649c76da63541f9d5b03059d39 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Mon, 23 Dec 2024 16:17:38 +0900 Subject: [PATCH 05/12] Add comments in add_nodes.cuh --- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index d03fe7b95..f65da7a94 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -154,6 +154,9 @@ void add_node_core( std::uint32_t detourable_node_count = 0; const auto a_id = host_neighbor_indices(vec_i, i); if (a_id >= idx.size()) { + // If the node ID is not valid, the number of detours is increased + // to a value greater than the maximum, so that the edge to that + // node is not selected as much as possible. 
detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1); continue; } From 5025481a492776ec0f0f90f6f990dde53b42e51e Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Mon, 23 Dec 2024 21:53:13 +0900 Subject: [PATCH 06/12] Limit the number of warnings output --- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 21 +++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index f65da7a94..0749daa0f 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -127,21 +127,28 @@ void add_node_core( raft::resource::sync_stream(handle); // Check search results + int num_warnings = 0; for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) { std::uint32_t invalid_edges = 0; for (std::uint32_t i = 0; i < base_degree; i++) { if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; } } if (invalid_edges > 0) { - RAFT_LOG_WARN( - "Invalid edges found in search results " - "(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)", - (uint64_t)vec_i, - (uint64_t)invalid_edges, - (uint64_t)degree, - (uint64_t)base_degree); + if (num_warnings < 3) { + RAFT_LOG_WARN( + "Invalid edges found in search results " + "(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)", + (uint64_t)vec_i, + (uint64_t)invalid_edges, + (uint64_t)degree, + (uint64_t)base_degree); + } + num_warnings += 1; } } + if (num_warnings > 0) { + RAFT_LOG_WARN("The number of queries that contain invalid search results: %d", num_warnings); + } // Step 2: rank-based reordering #pragma omp parallel From b61126a07af7cce8bbe8bb4e7a601f42166b5718 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Wed, 25 Dec 2024 16:19:37 +0900 Subject: [PATCH 07/12] Avoid invalid results in search results as much as possible --- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 7 +- .../detail/cagra/search_multi_cta.cuh | 2 +-
.../cagra/search_multi_cta_kernel-inl.cuh | 114 ++++++++++-------- 3 files changed, 69 insertions(+), 54 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 0749daa0f..189cda20a 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -127,14 +127,15 @@ void add_node_core( raft::resource::sync_stream(handle); // Check search results - int num_warnings = 0; + constexpr int max_warnings = 3; + int num_warnings = 0; for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) { std::uint32_t invalid_edges = 0; for (std::uint32_t i = 0; i < base_degree; i++) { if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; } } if (invalid_edges > 0) { - if (num_warnings < 3) { + if (num_warnings < max_warnings) { RAFT_LOG_WARN( "Invalid edges found in search results " "(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)", @@ -146,7 +147,7 @@ void add_node_core( num_warnings += 1; } } - if (num_warnings > 0) { + if (num_warnings > max_warnings) { RAFT_LOG_WARN("The number of queries that contain invalid search results: %d", num_warnings); } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 8a97173fa..9da591147 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -120,7 +120,7 @@ struct search : public search_plan_impl AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); // constexpr unsigned max_result_buffer_size = 256; diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index a42a46e19..ee9b9ff95 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -78,6 +78,7 @@ 
RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( int is_candidate = 0; if ((index != invalid_index) && ((index & index_msb_1_mask) == 0)) { +#if 0 if (hashmap::search(hash_ptr, hash_bitlen, index)) { // Deactivate nodes that have already been used by other CTAs. itopk_indices[lane_id] = invalid_index; @@ -86,6 +87,9 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( } else { is_candidate = 1; } +#else + is_candidate = 1; +#endif } uint32_t num_next_parents = 0; @@ -116,12 +120,9 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( } template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( - float* distances, // [num_elements] - INDEX_T* indices, // [num_elements] - const uint32_t num_elements, - const uint32_t num_itopk // num_itopk <= num_elements -) +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num_elements] + INDEX_T* indices, // [num_elements] + const uint32_t num_elements) { const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } @@ -239,11 +240,11 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( extern __shared__ uint8_t smem[]; // Layout of result_buffer - // +----------------+-------------------------------+---------+ - // | internal_top_k | neighbors of parent nodes | padding | - // | | | upto 32 | - // +----------------+-------------------------------+---------+ - // |<--- result_buffer_size --->| + // +----------------+---------+-------------------------------+ + // | internal_top_k | padding | neighbors of parent nodes | + // | | upto 32 | | + // +----------------+---------+-------------------------------+ + // |<--- result_buffer_size_32 --->| const auto result_buffer_size = itopk_size + (search_width * graph_degree); const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); assert(result_buffer_size_32 <= MAX_ELEMENTS); @@ -283,7 +284,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( 
device::compute_distance_to_random_nodes(result_indices_buffer, result_distances_buffer, *dataset_desc, - result_buffer_size, + graph_degree * search_width, num_distilation, rand_xor_mask, local_seed_ptr, @@ -300,11 +301,22 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( uint32_t iter = 0; while (1) { _CLK_START(); + // Checks the state of the node in the result buffer from the previous + // iteration, and if it cannot be used as a parent node, it is deactivated. + for (uint32_t i = threadIdx.x; i < result_buffer_size_32 - (graph_degree * search_width); + i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index || index & index_msb_1_mask) { continue; } + if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + // Topk with bitonic sort (1st warp only) - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); + topk_by_bitonic_sort( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); __syncthreads(); _CLK_REC(clk_topk); @@ -320,20 +332,19 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( itopk_size, local_traversed_hashmap_ptr, traversed_hash_bitlen); -#if 0 - if (parent_indices_buffer[0] == invalid_index) { - // Try again if no parent is found - move_valid_entries_to_head(result_indices_buffer, - result_distances_buffer, - result_buffer_size_32); - pickup_next_parents(parent_indices_buffer, - search_width, - result_indices_buffer, - result_distances_buffer, - itopk_size, - local_traversed_hashmap_ptr, - traversed_hash_bitlen); - } +#if 1 + if (parent_indices_buffer[0] == invalid_index) { + // Try again if no parent is found + move_valid_entries_to_head( + result_indices_buffer, result_distances_buffer, result_buffer_size_32); + 
pickup_next_parents(parent_indices_buffer, + search_width, + result_indices_buffer, + result_distances_buffer, + itopk_size, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); + } #endif } else { // [Other warps] Reset visited hashmap @@ -355,12 +366,15 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( } else { // [Other warps] Remove entries kicked out of the itopk list from the // traversed hash table. - for (unsigned i = threadIdx.x - 32; i < search_width * graph_degree; i += blockDim.x - 32) { - INDEX_T index = result_indices_buffer[itopk_size + i]; + for (unsigned i = itopk_size + threadIdx.x - 32; i < result_buffer_size_32; + i += blockDim.x - 32) { + INDEX_T index = result_indices_buffer[i]; if (index == invalid_index) { continue; } if (index & index_msb_1_mask) { hashmap::remove( local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); } } } @@ -368,18 +382,19 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( _CLK_START(); // compute the norms between child nodes and query node - device::compute_distance_to_child_nodes(result_indices_buffer + itopk_size, - result_distances_buffer + itopk_size, - *dataset_desc, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - visited_hash_bitlen, - local_traversed_hashmap_ptr, - traversed_hash_bitlen, - parent_indices_buffer, - result_indices_buffer, - search_width); + device::compute_distance_to_child_nodes( + result_indices_buffer + result_buffer_size_32 - (graph_degree * search_width), + result_distances_buffer + result_buffer_size_32 - (graph_degree * search_width), + *dataset_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + parent_indices_buffer, + result_indices_buffer, + search_width); __syncthreads(); _CLK_REC(clk_compute_distance); @@ -428,7 +443,8 @@ RAFT_KERNEL 
__launch_bounds__(1024, 1) search_kernel( if (index & index_msb_1_mask) { is_valid = true; index &= ~index_msb_1_mask; - } else if (hashmap::insert( + } else if ((offset < itopk_size) && + hashmap::insert( local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { // If a node that is not used as a parent can be inserted into // the traversed hash table, it is considered a valid result. @@ -444,15 +460,13 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( if (result_distances_ptr != nullptr) { result_distances_ptr[k] = result_distances_buffer[i]; } - } else if ((index & index_msb_1_mask) == 0) { - // If a node that was successfully inserted in the traversed - // hash table is not output as a result, the hash table is - // restored using hash remove. + } else { + // If it is valid and registered in the traversed hash table but is + // not output as a result, it is removed from the hash table. hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); } } offset += __popc(mask); - if (offset >= itopk_size) break; } // If the number of outputs is insufficient, fill in with invalid results. 
for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { From 588bd0c98fca84ae0533a6b76a303e9c456f5a6d Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Mon, 30 Dec 2024 15:07:20 +0900 Subject: [PATCH 08/12] Improve the accuracy of the new multi-CTA algo by revising the usage of result_buffer --- .../neighbors/detail/cagra/device_common.hpp | 33 ++- .../detail/cagra/search_multi_cta.cuh | 11 +- .../cagra/search_multi_cta_kernel-inl.cuh | 217 ++++++++---------- .../neighbors/detail/cagra/search_plan.cuh | 2 +- 4 files changed, 137 insertions(+), 126 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index c20d58994..0e004e233 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -166,7 +166,10 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( } } -template +template RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( IndexT* __restrict__ result_child_indices_ptr, DistanceT* __restrict__ result_child_distances_ptr, @@ -182,7 +185,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const uint32_t traversed_hash_bitlen, const IndexT* __restrict__ parent_indices, const IndexT* __restrict__ internal_topk_list, - const uint32_t search_width) + const uint32_t search_width, + IndexT* __restrict__ temp_indices_ptr = nullptr, + int* __restrict__ result_position = nullptr) { constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; constexpr IndexT invalid_index = ~static_cast(0); @@ -207,7 +212,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( child_id = invalid_index; } } - result_child_indices_ptr[i] = child_id; + if (STATIC_RESULT_POSITION) { + result_child_indices_ptr[i] = child_id; + } else { + temp_indices_ptr[i] = child_id; + } } __syncthreads(); @@ -219,8 +228,10 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const auto args = dataset_desc.args.load(); const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { - const bool valid_i = i < num_k; - const auto child_id = valid_i ? result_child_indices_ptr[i] : invalid_index; + const bool valid_i = i < num_k; + const auto child_id = + valid_i ? (STATIC_RESULT_POSITION ? result_child_indices_ptr[i] : temp_indices_ptr[i]) + : invalid_index; // We should be calling `dataset_desc.compute_distance(..)` here as follows: // > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index); @@ -230,9 +241,19 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( (child_id != invalid_index) ? compute_distance(args, child_id) : (lead_lane ? raft::upper_bound() : 0), team_size_bits); + __syncwarp(); // Store the distance - if (valid_i && lead_lane) { result_child_distances_ptr[i] = child_dist; } + if (valid_i && lead_lane) { + if (STATIC_RESULT_POSITION) { + result_child_distances_ptr[i] = child_dist; + } else if (child_id != invalid_index) { + // Only valid results are stored in order from the back of the buffer + int j = atomicSub(result_position, 1) - 1; + result_child_indices_ptr[j] = child_id; + result_child_distances_ptr[j] = child_dist; + } + } } } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 9da591147..d21a0407c 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -126,10 +126,13 @@ struct search : public search_plan_impl -RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( - INDEX_T* const next_parent_indices, // [num_parents] - const uint32_t num_parents, - INDEX_T* const itopk_indices, // [num_itopk] - DISTANCE_T* const itopk_distances, // [num_itopk] - const uint32_t num_itopk, // (*) num_itopk <= 32 
+RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parent( + INDEX_T* const next_parent_indices, + INDEX_T* const itopk_indices, // [itopk_size * 2] + DISTANCE_T* const itopk_distances, // [itopk_size * 2] INDEX_T* const hash_ptr, const uint32_t hash_bitlen) { + constexpr uint32_t itopk_size = 32; constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; constexpr INDEX_T invalid_index = ~static_cast(0); const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; + if (threadIdx.x == 0) { next_parent_indices[0] = invalid_index; } + __syncwarp(); - // Initialize - if (lane_id < num_parents) { next_parent_indices[lane_id] = ~static_cast(0); } - INDEX_T index = invalid_index; - if (lane_id < num_itopk) { index = itopk_indices[lane_id]; } - - int is_candidate = 0; - if ((index != invalid_index) && ((index & index_msb_1_mask) == 0)) { -#if 0 - if (hashmap::search(hash_ptr, hash_bitlen, index)) { - // Deactivate nodes that have already been used by other CTAs. 
- itopk_indices[lane_id] = invalid_index; - itopk_distances[lane_id] = utils::get_max_value(); - index = invalid_index; + int j = -1; + for (unsigned i = threadIdx.x; i < itopk_size * 2; i += 32) { + INDEX_T index = itopk_indices[i]; + int is_invalid = 0; + int is_candidate = 0; + if (index == invalid_index) { + is_invalid = 1; + } else if (index & index_msb_1_mask) { } else { is_candidate = 1; } -#else - is_candidate = 1; -#endif - } - uint32_t num_next_parents = 0; - while (num_next_parents < num_parents) { - const uint32_t ballot_mask = __ballot_sync(0xffffffff, is_candidate); - int num_candidates = __popc(ballot_mask); - if (num_candidates == 0) { return; } - int is_found = 0; - if (is_candidate) { - const auto candidate_id = __popc(ballot_mask & ((1 << lane_id) - 1)); - if (candidate_id == 0) { + const auto ballot_mask = __ballot_sync(0xffffffff, is_candidate); + const auto candidate_id = __popc(ballot_mask & ((1 << threadIdx.x) - 1)); + for (int k = 0; k < __popc(ballot_mask); k++) { + int flag_done = 0; + if (is_candidate && candidate_id == k) { + is_candidate = 0; if (hashmap::insert(hash_ptr, hash_bitlen, index)) { // Use this candidate as next parent - next_parent_indices[num_next_parents] = lane_id; index |= index_msb_1_mask; // set most significant bit as used node - is_found = 1; + if (i < itopk_size) { + next_parent_indices[0] = i; + itopk_indices[i] = index; + } else { + next_parent_indices[0] = j; + // Move the next parent node from i-th position to j-th position + itopk_indices[j] = index; + itopk_distances[j] = itopk_distances[i]; + itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + } + flag_done = 1; } else { // Deactivate the node since it has been used by other CTA. 
- index = invalid_index; - itopk_distances[lane_id] = utils::get_max_value(); + itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + is_invalid = 1; } - itopk_indices[lane_id] = index; - is_candidate = 0; } + if (__any_sync(0xffffffff, (flag_done > 0))) { return; } + } + if (i < itopk_size) { + j = 31 - __clz(__ballot_sync(0xffffffff, is_invalid)); + if (j < 0) { return; } } - if (__ballot_sync(0xffffffff, is_found)) { num_next_parents += 1; } } } @@ -207,7 +208,6 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] const uint32_t traversed_hash_bitlen, const uint32_t itopk_size, - const uint32_t search_width, const uint32_t min_iteration, const uint32_t max_iteration, uint32_t* const num_executed_iterations, /* stats */ @@ -240,12 +240,12 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( extern __shared__ uint8_t smem[]; // Layout of result_buffer - // +----------------+---------+-------------------------------+ - // | internal_top_k | padding | neighbors of parent nodes | - // | | upto 32 | | - // +----------------+---------+-------------------------------+ - // |<--- result_buffer_size_32 --->| - const auto result_buffer_size = itopk_size + (search_width * graph_degree); + // +----------------+---------+---------------------------+ + // | internal_top_k | padding | neighbors of parent nodes | + // | | upto 32 | | + // +----------------+---------+---------------------------+ + // |<--- result_buffer_size_32 --->| + const auto result_buffer_size = itopk_size + graph_degree; const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); assert(result_buffer_size_32 <= MAX_ELEMENTS); @@ -258,8 +258,11 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto* __restrict__ local_visited_hashmap_ptr = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* 
__restrict__ parent_indices_buffer = + auto* __restrict__ temp_indices_buffer = reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(temp_indices_buffer + graph_degree); + auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); INDEX_T* const local_traversed_hashmap_ptr = traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); @@ -284,7 +287,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( device::compute_distance_to_random_nodes(result_indices_buffer, result_distances_buffer, *dataset_desc, - graph_degree * search_width, + graph_degree, num_distilation, rand_xor_mask, local_seed_ptr, @@ -301,51 +304,24 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( uint32_t iter = 0; while (1) { _CLK_START(); - // Checks the state of the node in the result buffer from the previous - // iteration, and if it cannot be used as a parent node, it is deactivated. 
- for (uint32_t i = threadIdx.x; i < result_buffer_size_32 - (graph_degree * search_width); - i += blockDim.x) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index || index & index_msb_1_mask) { continue; } - if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { - result_indices_buffer[i] = invalid_index; - result_distances_buffer[i] = utils::get_max_value(); - } + if (threadIdx.x < 32) { + // [1st warp] Topk with bitonic sort + topk_by_bitonic_sort( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); } __syncthreads(); - - // Topk with bitonic sort (1st warp only) - topk_by_bitonic_sort( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); - __syncthreads(); _CLK_REC(clk_topk); if (iter + 1 >= max_iteration) { break; } _CLK_START(); if (threadIdx.x < 32) { - // [1st warp] Pick up next parents - pickup_next_parents(parent_indices_buffer, - search_width, - result_indices_buffer, - result_distances_buffer, - itopk_size, - local_traversed_hashmap_ptr, - traversed_hash_bitlen); -#if 1 - if (parent_indices_buffer[0] == invalid_index) { - // Try again if no parent is found - move_valid_entries_to_head( - result_indices_buffer, result_distances_buffer, result_buffer_size_32); - pickup_next_parents(parent_indices_buffer, - search_width, - result_indices_buffer, - result_distances_buffer, - itopk_size, - local_traversed_hashmap_ptr, - traversed_hash_bitlen); - } -#endif + // [1st warp] Pick up a next parent + pickup_next_parent(parent_indices_buffer, + result_indices_buffer, + result_distances_buffer, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); } else { // [Other warps] Reset visited hashmap hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); @@ -355,36 +331,33 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } - if (threadIdx.x < 32) { - // [1st warp] Restore 
visited hashmap by putting itopk indices in it. - for (unsigned i = threadIdx.x; i < itopk_size; i += 32) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index) { continue; } - index &= ~index_msb_1_mask; - hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); - } - } else { - // [Other warps] Remove entries kicked out of the itopk list from the - // traversed hash table. - for (unsigned i = itopk_size + threadIdx.x - 32; i < result_buffer_size_32; - i += blockDim.x - 32) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index) { continue; } - if (index & index_msb_1_mask) { - hashmap::remove( - local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); - result_indices_buffer[i] = invalid_index; - result_distances_buffer[i] = utils::get_max_value(); - } + _CLK_START(); + // Restore visited hashmap by putting nodes on result buffer in it. + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); + } + // Remove nodes kicked out of the itopk list from the traversed hash table. + for (unsigned i = itopk_size + threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + if (index & index_msb_1_mask) { + hashmap::remove( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); } } + // Initialize buffer for compute_distance_to_child_nodes. 
+ if (threadIdx.x == blockDim.x - 1) { result_position[0] = result_buffer_size_32; } __syncthreads(); - _CLK_START(); - // compute the norms between child nodes and query node - device::compute_distance_to_child_nodes( - result_indices_buffer + result_buffer_size_32 - (graph_degree * search_width), - result_distances_buffer + result_buffer_size_32 - (graph_degree * search_width), + // Compute the norms between child nodes and query node + device::compute_distance_to_child_nodes( + result_indices_buffer, + result_distances_buffer, *dataset_desc, knn_graph, graph_degree, @@ -394,14 +367,29 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( traversed_hash_bitlen, parent_indices_buffer, result_indices_buffer, - search_width); + 1, + temp_indices_buffer, + result_position); + __syncthreads(); + + // Check the state of the nodes in the result buffer which were not updated + // by the compute_distance_to_child_nodes above, and if it cannot be used as + // a parent node, it is deactivated. 
+ for (uint32_t i = threadIdx.x; i < result_position[0]; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index || index & index_msb_1_mask) { continue; } + if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } __syncthreads(); _CLK_REC(clk_compute_distance); // Filtering if constexpr (!std::is_same::value) { - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + for (unsigned p = threadIdx.x; p < 1; p += blockDim.x) { if (parent_indices_buffer[p] != invalid_index) { const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; @@ -615,7 +603,6 @@ void select_and_run(const dataset_descriptor_host& dat traversed_hashmap_ptr, traversed_hash_bitlen, ps.itopk_size, - ps.search_width, ps.min_iterations, ps.max_iterations, num_executed_iterations, diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 7300c89c6..0246d5f67 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -249,7 +249,7 @@ struct search_plan_impl : public search_plan_impl_base { // table that each CTA has in the shared memory. This hash table is not // shared among CTAs. This hash table is reset and restored in each iteration. 
// - const uint32_t max_visited_nodes = mc_itopk_size + graph_degree; + const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * 2); small_hash_bitlen = 8; // 256 while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { small_hash_bitlen += 1; From 228a1aebb73071491ffac22e42e75ef90df27aaf Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Mon, 6 Jan 2025 17:28:53 +0900 Subject: [PATCH 09/12] Reduce the number of shared memory access --- .../neighbors/detail/cagra/device_common.hpp | 28 +++++++------------ .../detail/cagra/search_multi_cta.cuh | 1 - .../cagra/search_multi_cta_kernel-inl.cuh | 25 +++++++---------- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 0e004e233..2c2c67fcd 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -186,8 +186,8 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const IndexT* __restrict__ parent_indices, const IndexT* __restrict__ internal_topk_list, const uint32_t search_width, - IndexT* __restrict__ temp_indices_ptr = nullptr, - int* __restrict__ result_position = nullptr) + int* __restrict__ result_position = nullptr, + const int max_result_position = 0) { constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; constexpr IndexT invalid_index = ~static_cast(0); @@ -214,8 +214,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( } if (STATIC_RESULT_POSITION) { result_child_indices_ptr[i] = child_id; - } else { - temp_indices_ptr[i] = child_id; + } else if (child_id != invalid_index) { + int j = atomicSub(result_position, 1) - 1; + result_child_indices_ptr[j] = child_id; } } __syncthreads(); @@ -227,11 +228,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const auto compute_distance = dataset_desc.compute_distance_impl; const auto args = 
dataset_desc.args.load(); const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; + const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0]; for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { - const bool valid_i = i < num_k; - const auto child_id = - valid_i ? (STATIC_RESULT_POSITION ? result_child_indices_ptr[i] : temp_indices_ptr[i]) - : invalid_index; + const auto j = i + ofst; + const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position); + const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index; // We should be calling `dataset_desc.compute_distance(..)` here as follows: // > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index); @@ -244,16 +245,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( __syncwarp(); // Store the distance - if (valid_i && lead_lane) { - if (STATIC_RESULT_POSITION) { - result_child_distances_ptr[i] = child_dist; - } else if (child_id != invalid_index) { - // Only valid results are stored in order from the back of the buffer - int j = atomicSub(result_position, 1) - 1; - result_child_indices_ptr[j] = child_id; - result_child_distances_ptr[j] = child_dist; - } - } + if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; } } } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index d21a0407c..ac361e800 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -130,7 +130,6 @@ struct search : public search_plan_impl(result_indices_buffer + result_buffer_size_32); auto* __restrict__ local_visited_hashmap_ptr = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* __restrict__ temp_indices_buffer = - reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); auto* 
__restrict__ parent_indices_buffer = - reinterpret_cast(temp_indices_buffer + graph_degree); + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); INDEX_T* const local_traversed_hashmap_ptr = @@ -332,22 +330,19 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } _CLK_START(); - // Restore visited hashmap by putting nodes on result buffer in it. for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { INDEX_T index = result_indices_buffer[i]; if (index == invalid_index) { continue; } - index &= ~index_msb_1_mask; - hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); - } - // Remove nodes kicked out of the itopk list from the traversed hash table. - for (unsigned i = itopk_size + threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index) { continue; } - if (index & index_msb_1_mask) { + if ((i >= itopk_size) && (index & index_msb_1_mask)) { + // Remove nodes kicked out of the itopk list from the traversed hash table. hashmap::remove( local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); result_indices_buffer[i] = invalid_index; result_distances_buffer[i] = utils::get_max_value(); + } else { + // Restore visited hashmap by putting nodes on result buffer in it. + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); } } // Initialize buffer for compute_distance_to_child_nodes. 
@@ -368,9 +363,9 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( parent_indices_buffer, result_indices_buffer, 1, - temp_indices_buffer, - result_position); - __syncthreads(); + result_position, + result_buffer_size_32); + // __syncthreads(); // Check the state of the nodes in the result buffer which were not updated // by the compute_distance_to_child_nodes above, and if it cannot be used as From 776f2f5d0b2cf84bfa0cc20794b10c8bcb562f5a Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Wed, 8 Jan 2025 16:18:14 +0900 Subject: [PATCH 10/12] Remove unused code --- .../cagra/search_multi_cta_kernel-inl.cuh | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 4a04cc1b9..a68156567 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -153,39 +153,6 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num } } -template -RAFT_DEVICE_INLINE_FUNCTION void move_valid_entries_to_head( - INDEX_T* indices, // [num_elements] - DISTANCE_T* distances, // [num_elements] - const uint32_t num_elements // (*) num_elements must be multiple of 32 -) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - constexpr INDEX_T invalid_index = ~static_cast(0); - const uint32_t warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - const uint32_t lane_id = threadIdx.x % 32; - uint32_t offset = 0; - for (uint32_t i = lane_id; i < num_elements; i += 32) { - auto index = indices[i]; - auto distance = distances[i]; - bool is_valid = (index != invalid_index); - const auto mask = __ballot_sync(0xffffffff, is_valid); - const auto j = offset + __popc(mask & ((1 << lane_id) - 1)); - if ((index != invalid_index) && (j < i)) { - indices[j] = index; - distances[j] = distance; - } - offset += 
__popc(mask); - __syncwarp(); - } - for (uint32_t i = offset + lane_id; i < num_elements; i += 32) { - indices[i] = invalid_index; - distances[i] = utils::get_max_value(); - } - __syncwarp(); -} - // // multiple CTAs per single query // From 192c0a9af13955fa0017626091a37b1c7f217f5f Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Fri, 10 Jan 2025 18:25:03 +0900 Subject: [PATCH 11/12] Update cpp/src/neighbors/detail/cagra/device_common.hpp Co-authored-by: Artem M. Chirkin <9253178+achirkin@users.noreply.github.com> --- cpp/src/neighbors/detail/cagra/device_common.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 2c2c67fcd..e5886582d 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -155,7 +155,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( } else if ((traversed_hash_ptr != nullptr) && hashmap::search( traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { - // Deactivate this entry as it has been already used by otehrs. + // Deactivate this entry as it has been already used by others. 
best_norm2_team_local = raft::upper_bound(); best_index_team_local = raft::upper_bound(); } From 81e4b3904a55f39286b043e1c482c409cabaabf4 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Fri, 17 Jan 2025 15:10:41 +0900 Subject: [PATCH 12/12] Fixed data type issues --- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 8f22aecfb..63f5c51a6 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -268,10 +268,10 @@ void add_node_core( interleave_switch = 1 - interleave_switch; } if (num_add < degree) { - RAFT_FAIL("Number of edges is not enough (target_new_node_id:%u, num_add:%u, degree:%u)", - target_new_node_id, - num_add, - degree); + RAFT_FAIL("Number of edges is not enough (target_new_node_id:%lu, num_add:%lu, degree:%lu)", + (uint64_t)target_new_node_id, + (uint64_t)num_add, + (uint64_t)degree); } for (std::uint32_t i = 0; i < degree; i++) { updated_graph(target_new_node_id, i) = temp[i];