Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve multi-CTA algorithm #492

Open
wants to merge 18 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DataT, IndexT, DistanceT, CagraSampleFilterT_s>> plan =
factory<DataT, IndexT, DistanceT, CagraSampleFilterT_s>::create(
res, params, dataset_desc, queries.extent(1), graph.extent(1), topk);
res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk);

plan->check(topk);

Expand Down
37 changes: 27 additions & 10 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const IndexT* __restrict__ seed_ptr, // [num_seeds]
const uint32_t num_seeds,
IndexT* __restrict__ visited_hash_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -145,14 +147,21 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(

const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (best_index_team_local != raft::upper_bound<IndexT>() &&
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
} else {
result_distances_ptr[i] = raft::upper_bound<DistanceT>();
result_indices_ptr[i] = raft::upper_bound<IndexT>();
if (best_index_team_local != raft::upper_bound<IndexT>()) {
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
} else if ((traversed_hash_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) {
// Deactivate this entry as it has been already used by others.
anaruse marked this conversation as resolved.
Show resolved Hide resolved
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
}
}
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
}
}
}
Expand All @@ -168,7 +177,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const uint32_t knn_k,
// hashmap
IndexT* __restrict__ visited_hashmap_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hashmap_ptr,
const uint32_t traversed_hash_bitlen,
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width)
Expand All @@ -186,7 +197,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
child_id = invalid_index;
} else if ((traversed_hashmap_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) {
// Deactivate this entry as this has been already used by others.
child_id = invalid_index;
}
}
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class factory {
search_params const& params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk);
return dispatch_kernel(res, plan, dataset_desc);
}

Expand All @@ -56,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
87 changes: 74 additions & 13 deletions cpp/src/neighbors/detail/cagra/hashmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include <cstdint>

#define HASHMAP_LINEAR_PROBING

// #pragma GCC diagnostic push
// #pragma GCC diagnostic ignored
// #pragma GCC diagnostic pop
Expand All @@ -42,15 +44,15 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table,
}
}

template <class IdxT>
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
{
// Open addressing is used for collision resolution
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#if 1
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
Expand All @@ -59,32 +61,91 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
uint32_t index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT old = atomicCAS(&table[index], ~static_cast<IdxT>(0), key);
if (old == ~static_cast<IdxT>(0)) {
const IdxT old = atomicCAS(&table[index], hashval_empty, key);
if (old == hashval_empty) {
return 1;
} else if (old == key) {
return 0;
} else if (SUPPORT_REMOVE) {
// Checks if this key has been removed before.
const uint32_t old = atomicCAS(&table[index], removed_key, key);
if (old == removed_key) {
return 1;
} else if (old == key) {
return 0;
}
}
index = (index + stride) & bit_mask;
}
return 0;
}

template <unsigned TEAM_SIZE, class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key)
{
IdxT ret = 0;
if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); }
for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) {
ret |= __shfl_xor_sync(0xffffffff, ret, offset);
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT val = table[index];
if (val == key) {
return 1;
} else if (val == hashval_empty) {
return 0;
} else if (SUPPORT_REMOVE) {
// Check if this key has been removed.
if (val == removed_key) {
return 0;
}
}
index = (index + stride) & bit_mask;
}
return ret;
return 0;
}

/**
 * @brief Mark `key` as removed in the open-addressing hash table.
 *
 * Probes the table starting at the slot derived from `key` and atomically replaces the first
 * slot that holds `key` with `key | utils::gen_index_msb_1_mask<IdxT>::value` (the key with
 * its MSB set). The slot is not reset to the empty value; it keeps a "removed" marker.
 *
 * @tparam IdxT   Type of the keys stored in the table.
 * @param table   Hash table storage; assumed to hold at least `get_size(bitlen)` slots.
 * @param bitlen  Bit length handed to `get_size()` to obtain the number of slots. The probe
 *                index is masked with `size - 1`, so the size is presumably a power of two.
 * @param key     Key to remove.
 * @return 1 if `key` was found and marked as removed; 0 if an empty slot was reached first or
 *         all slots were probed without finding `key`.
 */
template <class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key)
{
  const uint32_t size     = get_size(bitlen);
  const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
  // Linear probing
  IdxT index                = (key ^ (key >> bitlen)) & bit_mask;
  constexpr uint32_t stride = 1;
#else
  // Double hashing
  IdxT index            = key & bit_mask;
  const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
  constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
  const IdxT removed_key       = key | utils::gen_index_msb_1_mask<IdxT>::value;
  for (unsigned i = 0; i < size; i++) {
    // To remove a key, set the MSB to 1.
    // The previous slot value is kept as IdxT (not uint32_t) so that it is not narrowed when
    // IdxT is wider than 32 bits; a narrowed value could never equal `hashval_empty` and could
    // spuriously equal `key`.
    const IdxT old = atomicCAS(&table[index], key, removed_key);
    if (old == key) {
      // The slot held `key` and now holds the removed marker.
      return 1;
    }
    if (old == hashval_empty) {
      // Reached an unused slot: `key` is not in the table.
      return 0;
    }
    index = (index + stride) & bit_mask;
  }
  // Every slot was probed without finding `key`.
  return 0;
}

template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t
insert(unsigned team_size, IdxT* const table, const uint32_t bitlen, const IdxT key)
{
Expand Down
16 changes: 9 additions & 7 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,24 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
search_params params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
: base_type(res, params, dataset_desc, dim, graph_degree, topk),
: base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk),
intermediate_indices(res),
intermediate_distances(res),
topk_workspace(res)

{
set_params(res, params);
}

void set_params(raft::resources const& res, const search_params& params)
{
constexpr unsigned muti_cta_itopk_size = 32;
this->itopk_size = muti_cta_itopk_size;
search_width = 1;
constexpr unsigned multi_cta_itopk_size = 32;
this->itopk_size = multi_cta_itopk_size;
search_width = 1;
num_cta_per_query =
max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size));
max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size));
result_buffer_size = itopk_size + search_width * graph_degree;
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
Expand All @@ -128,7 +128,8 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_

smem_size = dataset_desc.smem_ws_size_in_bytes +
(sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 +
sizeof(uint32_t) * search_width + sizeof(uint32_t);
sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) +
sizeof(INDEX_T) * search_width;
RAFT_LOG_DEBUG("# smem_size: %u", smem_size);

//
Expand Down Expand Up @@ -222,6 +223,7 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
thread_block_size,
result_buffer_size,
smem_size,
small_hash_bitlen,
hash_bitlen,
hashmap.data(),
num_cta_per_query,
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search {
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
uint32_t small_hash_bitlen, \
int64_t hash_bitlen, \
IndexT* hashmap_ptr, \
uint32_t num_cta_per_query, \
Expand Down
Loading
Loading