diff --git a/.gitmodules b/.gitmodules index 125572bcd..3d81e7d9d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "gperftools"] path = gperftools url = https://github.com/gperftools/gperftools.git +[submodule "CRoaring"] + path = CRoaring + url = https://github.com/RoaringBitmap/CRoaring diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d3d2b860..743890ca3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,6 +206,31 @@ endif() add_definitions(-DMKL_ILP64) + +# Roaring Bitmap +#if ( NOT EXISTS "${PROJECT_SOURCE_DIR}/CRoaring/LICENSE") +# message(FATAL_ERROR "The RoaringBitmap submodule was not found. " +# " Please run 'git submodule init' followed by 'git submodule update'") +#endif() +#add_subdirectory(CRoaring) +#include_directories(CRoaring/include/roaring) + +#set (CROARING_LIBRARY "${PROJECT_SOURCEDIR}/CRoaring/build/src/Release/roaring.lib") +#add_custom_target(build_croaring DEPENDS CROARING_LIBRARY) +#if (MSVC) +# add_custom_command(OUTPUT ${CROARING_LIBRARY} +# COMMAND ${CMAKE_VS_MSBUILD_COMMAND} build/RoaringBitmap.sln /m /nologo +# /t:roaring /p:Configuration="Release" +# /property:Platform="x64" +# /p:PlatformToolset=v${MSVC_TOOLSET_VERSION} +# /p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION} +# WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/CRoaring) +#endif() +#add_library(croaring_lib STATIC IMPORTED) +#set_target_properties(croaring_lib PROPERTIES IMPORTED_LOCATION "${CROARING_LIBRARY}") + + + # Section for tcmalloc. The DiskANN tools are always linked to tcmalloc. For Windows, they also need to # force-include the _tcmalloc symbol for enabling tcmalloc. 
# diff --git a/CRoaring b/CRoaring new file mode 160000 index 000000000..ad487efbb --- /dev/null +++ b/CRoaring @@ -0,0 +1 @@ +Subproject commit ad487efbb48637445e6a5eee10c81cc9fdd7bf76 diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index f48b61726..2da62fd67 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include +#include -#include "utils.h" #include "disk_utils.h" -#include "math_utils.h" #include "index.h" +#include "math_utils.h" #include "partition.h" #include "program_options_utils.hpp" +#include "utils.h" namespace po = boost::program_options; @@ -17,7 +17,7 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, label_type; - uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; + uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold, filter_bf_threshold; float B, M; bool append_reorder_data = false; bool use_opq = false; @@ -74,8 +74,9 @@ int main(int argc, char **argv) optional_configs.add_options()("FilteredLbuild", po::value(&Lf)->default_value(0), program_options_utils::FILTERED_LBUILD); optional_configs.add_options()("filter_threshold,F", po::value(&filter_threshold)->default_value(0), - "Threshold to break up the existing nodes to generate new graph " - "internally where each node has a maximum F labels."); + program_options_utils::FILTER_THRESHOLD_DESCRIPTION); + optional_configs.add_options()("filter_bruteforce_threshold", po::value(&filter_bf_threshold)->default_value(0), + program_options_utils::FILTER_BRUTEFORCE_THRESHOLD_DESCRIPTION); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); @@ -139,6 +140,10 @@ int main(int argc, char **argv) 
std::string(std::to_string(append_reorder_data)) + " " + std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)); + if (filter_bf_threshold == 0) { + filter_bf_threshold = std::numeric_limits::max(); + } + try { if (label_file != "" && label_type == "ushort") @@ -146,15 +151,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("uint8")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("float")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, filter_bf_threshold); else { diskann::cerr << "Error. 
Unsupported data type" << std::endl; @@ -166,15 +171,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("uint8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("float")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else { diskann::cerr << "Error. Unsupported data type" << std::endl; diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 544e42dee..f0d469f4d 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include #include +#include +#include #include "index.h" -#include "utils.h" #include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include @@ -16,9 +16,9 @@ #include #endif -#include "memory_mapper.h" #include "ann_exception.h" #include "index_factory.h" +#include "memory_mapper.h" namespace po = boost::program_options; diff --git a/apps/build_stitched_index.cpp b/apps/build_stitched_index.cpp index 60e38c1be..0f385cb88 100644 --- a/apps/build_stitched_index.cpp +++ b/apps/build_stitched_index.cpp @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
+#include "filter_utils.h" #include #include #include #include +#include #include #include #include -#include "filter_utils.h" -#include #ifndef _WINDOWS #include #endif @@ -17,8 +17,8 @@ #include "index.h" #include "memory_mapper.h" #include "parameters.h" -#include "utils.h" #include "program_options_utils.hpp" +#include "utils.h" namespace po = boost::program_options; typedef std::tuple>, uint64_t> stitch_indices_return_values; diff --git a/apps/range_search_disk_index.cpp b/apps/range_search_disk_index.cpp index 31675724b..3975298ae 100644 --- a/apps/range_search_disk_index.cpp +++ b/apps/range_search_disk_index.cpp @@ -2,26 +2,26 @@ // Licensed under the MIT license. #include +#include #include #include #include #include -#include -#include "index.h" #include "disk_utils.h" +#include "index.h" #include "math_utils.h" #include "memory_mapper.h" -#include "pq_flash_index.h" #include "partition.h" -#include "timer.h" +#include "pq_flash_index.h" #include "program_options_utils.hpp" +#include "timer.h" #ifndef _WINDOWS +#include "linux_aligned_file_reader.h" #include #include #include -#include "linux_aligned_file_reader.h" #else #ifdef USE_BING_INFRA #include "bing_aligned_file_reader.h" diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 7e2a7ac6d..925f31775 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -4,21 +4,21 @@ #include "common_includes.h" #include -#include "index.h" #include "disk_utils.h" +#include "index.h" #include "math_utils.h" #include "memory_mapper.h" #include "partition.h" -#include "pq_flash_index.h" -#include "timer.h" #include "percentile_stats.h" +#include "pq_flash_index.h" #include "program_options_utils.hpp" +#include "timer.h" #ifndef _WINDOWS +#include "linux_aligned_file_reader.h" #include #include #include -#include "linux_aligned_file_reader.h" #else #ifdef USE_BING_INFRA #include "bing_aligned_file_reader.h" @@ -123,8 +123,8 @@ int search_disk_index(diskann::Metric &metric, const 
std::string &index_path_pre diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl; _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); // if (num_nodes_to_cache > 0) - // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache, - // num_threads, node_list); + // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, + // 15, 6, num_nodes_to_cache, num_threads, node_list); _pFlashIndex->load_cache_list(node_list); node_list.clear(); node_list.shrink_to_fit(); diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1a9acc285..9126ad1fc 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include #include #include -#include #include #include #include #include -#include #ifndef _WINDOWS #include @@ -18,10 +18,10 @@ #endif #include "index.h" +#include "index_factory.h" #include "memory_mapper.h" -#include "utils.h" #include "program_options_utils.hpp" -#include "index_factory.h" +#include "utils.h" namespace po = boost::program_options; @@ -323,9 +323,9 @@ int main(int argc, char **argv) optional_configs.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()( - "dynamic", po::value(&dynamic)->default_value(false), - "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); + optional_configs.add_options()("dynamic", po::value(&dynamic)->default_value(false), + "Whether the index is dynamic. Dynamic indices must have associated " + "tags. Default false."); optional_configs.add_options()("tags", po::value(&tags)->default_value(false), "Whether to search with external identifiers (tags). 
Default false."); optional_configs.add_options()("fail_if_recall_below", diff --git a/apps/test_insert_deletes_consolidate.cpp b/apps/test_insert_deletes_consolidate.cpp index 97aed1864..21ce4250f 100644 --- a/apps/test_insert_deletes_consolidate.cpp +++ b/apps/test_insert_deletes_consolidate.cpp @@ -1,19 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include #include #include #include #include #include #include -#include -#include -#include "utils.h" #include "filter_utils.h" -#include "program_options_utils.hpp" #include "index_factory.h" +#include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include diff --git a/apps/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp index 5a43a69f3..d8ea0577c 100644 --- a/apps/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -1,20 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include +#include #include +#include #include #include #include #include #include -#include -#include -#include -#include -#include "utils.h" #include "filter_utils.h" #include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include diff --git a/apps/utils/bin_to_fvecs.cpp b/apps/utils/bin_to_fvecs.cpp index e9a6a8ecc..ebd8229ba 100644 --- a/apps/utils/bin_to_fvecs.cpp +++ b/apps/utils/bin_to_fvecs.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "util.h" +#include void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts, uint64_t ndims) diff --git a/apps/utils/bin_to_tsv.cpp b/apps/utils/bin_to_tsv.cpp index 7851bef6d..5c31c8595 100644 --- a/apps/utils/bin_to_tsv.cpp +++ b/apps/utils/bin_to_tsv.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include "utils.h" +#include template void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims) diff --git a/apps/utils/calculate_recall.cpp b/apps/utils/calculate_recall.cpp index dc76252cc..3946bfdf2 100644 --- a/apps/utils/calculate_recall.cpp +++ b/apps/utils/calculate_recall.cpp @@ -9,8 +9,8 @@ #include #include -#include "utils.h" #include "disk_utils.h" +#include "utils.h" int main(int argc, char **argv) { diff --git a/apps/utils/compute_groundtruth.cpp b/apps/utils/compute_groundtruth.cpp index da32fd7c6..b86f28289 100644 --- a/apps/utils/compute_groundtruth.cpp +++ b/apps/utils/compute_groundtruth.cpp @@ -1,25 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include #include +#include +#include +#include -#include #include +#include #include #include -#include -#include #include -#include -#include +#include #include -#include -#include +#include +#include +#include #include #include +#include +#include #ifdef _WINDOWS #include diff --git a/apps/utils/compute_groundtruth_for_filters.cpp b/apps/utils/compute_groundtruth_for_filters.cpp index 52e586475..e90da2444 100644 --- a/apps/utils/compute_groundtruth_for_filters.cpp +++ b/apps/utils/compute_groundtruth_for_filters.cpp @@ -1,25 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include #include +#include +#include +#include -#include #include +#include #include #include -#include -#include #include -#include -#include +#include #include -#include -#include +#include +#include +#include #include #include +#include +#include #ifdef _WINDOWS #include diff --git a/apps/utils/count_bfs_levels.cpp b/apps/utils/count_bfs_levels.cpp index 6dd2d6233..6e45ef13d 100644 --- a/apps/utils/count_bfs_levels.cpp +++ b/apps/utils/count_bfs_levels.cpp @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT license. +#include +#include #include #include -#include #include #include #include #include -#include #ifndef _WINDOWS #include @@ -17,9 +17,9 @@ #include #endif -#include "utils.h" #include "index.h" #include "memory_mapper.h" +#include "utils.h" namespace po = boost::program_options; diff --git a/apps/utils/create_disk_layout.cpp b/apps/utils/create_disk_layout.cpp index f494c1227..6d5314fb4 100644 --- a/apps/utils/create_disk_layout.cpp +++ b/apps/utils/create_disk_layout.cpp @@ -8,9 +8,9 @@ #include #include -#include "utils.h" -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" +#include "utils.h" template int create_disk_layout(char **argv) { diff --git a/apps/utils/float_bin_to_int8.cpp b/apps/utils/float_bin_to_int8.cpp index 1982005af..c3fa8f8ec 100644 --- a/apps/utils/float_bin_to_int8.cpp +++ b/apps/utils/float_bin_to_int8.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts, size_t ndims, float bias, float scale) diff --git a/apps/utils/fvecs_to_bin.cpp b/apps/utils/fvecs_to_bin.cpp index 873ad3b0c..1428a9c6e 100644 --- a/apps/utils/fvecs_to_bin.cpp +++ b/apps/utils/fvecs_to_bin.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include // Convert float types void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts, diff --git a/apps/utils/fvecs_to_bvecs.cpp b/apps/utils/fvecs_to_bvecs.cpp index f9c2aa71b..60ac12126 100644 --- a/apps/utils/fvecs_to_bvecs.cpp +++ b/apps/utils/fvecs_to_bvecs.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include "utils.h" +#include void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts, size_t ndims) diff --git a/apps/utils/gen_random_slice.cpp b/apps/utils/gen_random_slice.cpp index a4cd96e0a..64bc994ef 100644 --- a/apps/utils/gen_random_slice.cpp +++ b/apps/utils/gen_random_slice.cpp @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include +#include "partition.h" +#include "utils.h" #include #include #include @@ -10,10 +11,9 @@ #include #include #include +#include #include #include -#include "partition.h" -#include "utils.h" #include #include diff --git a/apps/utils/generate_pq.cpp b/apps/utils/generate_pq.cpp index a881b1104..cff7a3526 100644 --- a/apps/utils/generate_pq.cpp +++ b/apps/utils/generate_pq.cpp @@ -2,8 +2,8 @@ // Licensed under the MIT license. #include "math_utils.h" -#include "pq.h" #include "partition.h" +#include "pq.h" #define KMEANS_ITERS_FOR_PQ 15 diff --git a/apps/utils/generate_synthetic_labels.cpp b/apps/utils/generate_synthetic_labels.cpp index 6741760cb..766c297d7 100644 --- a/apps/utils/generate_synthetic_labels.cpp +++ b/apps/utils/generate_synthetic_labels.cpp @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include +#include "utils.h" #include -#include #include -#include "utils.h" +#include +#include +#include namespace po = boost::program_options; class ZipfDistribution diff --git a/apps/utils/int8_to_float.cpp b/apps/utils/int8_to_float.cpp index dcdfddc0d..8277b9a09 100644 --- a/apps/utils/int8_to_float.cpp +++ b/apps/utils/int8_to_float.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include "utils.h" +#include int main(int argc, char **argv) { diff --git a/apps/utils/int8_to_float_scale.cpp b/apps/utils/int8_to_float_scale.cpp index 19fbc6c43..757e79be1 100644 --- a/apps/utils/int8_to_float_scale.cpp +++ b/apps/utils/int8_to_float_scale.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts, size_t ndims, float bias, float scale) diff --git a/apps/utils/ivecs_to_bin.cpp b/apps/utils/ivecs_to_bin.cpp index ea8a4a3d2..854c06839 100644 --- a/apps/utils/ivecs_to_bin.cpp +++ b/apps/utils/ivecs_to_bin.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts, size_t ndims) diff --git a/apps/utils/merge_shards.cpp b/apps/utils/merge_shards.cpp index 106c15eef..be64e6ff9 100644 --- a/apps/utils/merge_shards.cpp +++ b/apps/utils/merge_shards.cpp @@ -10,8 +10,8 @@ #include #include -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" #include "utils.h" int main(int argc, char **argv) diff --git a/apps/utils/partition_data.cpp b/apps/utils/partition_data.cpp index 2520f3f4a..42c22d231 100644 --- a/apps/utils/partition_data.cpp +++ b/apps/utils/partition_data.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include -#include #include "cached_io.h" #include "partition.h" +#include +#include // DEPRECATED: NEED TO REPROGRAM diff --git a/apps/utils/partition_with_ram_budget.cpp b/apps/utils/partition_with_ram_budget.cpp index 937b68d2c..c5b6ed596 100644 --- a/apps/utils/partition_with_ram_budget.cpp +++ b/apps/utils/partition_with_ram_budget.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include #include "cached_io.h" #include "partition.h" +#include +#include // DEPRECATED: NEED TO REPROGRAM diff --git a/apps/utils/rand_data_gen.cpp b/apps/utils/rand_data_gen.cpp index e89ede800..799aa0f33 100644 --- a/apps/utils/rand_data_gen.cpp +++ b/apps/utils/rand_data_gen.cpp @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include +#include +#include #include +#include #include -#include -#include #include "utils.h" @@ -128,7 +128,8 @@ int main(int argc, char **argv) desc.add_options()("norm", po::value(&norm)->default_value(-1.0f), "Norm of the vectors (if not specified, vectors are not normalized)"); desc.add_options()("rand_scaling", po::value(&rand_scaling)->default_value(1.0f), - "Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from " + "Each vector will be scaled (if not explicitly normalized) by a factor " + "randomly chosen from " "[1, rand_scale]. Only applicable for floating point data"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -158,13 +159,17 @@ int main(int argc, char **argv) if (rand_scaling < 1.0) { - std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl; + std::cout << "We will only scale the vector norms randomly in [1, value], " + "so value must be >= 1." 
+ << std::endl; return -1; } if ((rand_scaling > 1.0) && (normalization == true)) { - std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl; + std::cout << "Data cannot be normalized and randomly scaled at same time. " + "Use one or the other." + << std::endl; return -1; } diff --git a/apps/utils/simulate_aggregate_recall.cpp b/apps/utils/simulate_aggregate_recall.cpp index 73c4ea0f7..30cb24f13 100644 --- a/apps/utils/simulate_aggregate_recall.cpp +++ b/apps/utils/simulate_aggregate_recall.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include +#include #include +#include #include -#include inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count, const std::vector &recalls) diff --git a/apps/utils/stats_label_data.cpp b/apps/utils/stats_label_data.cpp index 3342672ff..1fad04b61 100644 --- a/apps/utils/stats_label_data.cpp +++ b/apps/utils/stats_label_data.cpp @@ -1,28 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include +#include #include +#include #include +#include +#include #include -#include +#include +#include +#include +#include +#include #include "utils.h" #ifndef _WINDOWS #include -#include #include #include +#include #else #include #endif diff --git a/apps/utils/tsv_to_bin.cpp b/apps/utils/tsv_to_bin.cpp index c590a8f73..9d52f70a2 100644 --- a/apps/utils/tsv_to_bin.cpp +++ b/apps/utils/tsv_to_bin.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include "utils.h" +#include void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) { diff --git a/apps/utils/uint32_to_uint8.cpp b/apps/utils/uint32_to_uint8.cpp index 87b6fb8ed..348dcaa20 100644 --- a/apps/utils/uint32_to_uint8.cpp +++ b/apps/utils/uint32_to_uint8.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include int main(int argc, char **argv) { diff --git a/apps/utils/uint8_to_float.cpp b/apps/utils/uint8_to_float.cpp index 6415b7c92..352aea00c 100644 --- a/apps/utils/uint8_to_float.cpp +++ b/apps/utils/uint8_to_float.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include int main(int argc, char **argv) { diff --git a/apps/utils/vector_analysis.cpp b/apps/utils/vector_analysis.cpp index 009df6d05..63364bc67 100644 --- a/apps/utils/vector_analysis.cpp +++ b/apps/utils/vector_analysis.cpp @@ -1,18 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include #include #include #include #include +#include #include #include #include +#include #include #include -#include #include #include #include diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index 89856f1fa..44401a08f 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -3,12 +3,12 @@ #pragma once -#include #include +#include +#include "distance.h" #include "types.h" #include "windows_customizations.h" -#include "distance.h" namespace diskann { @@ -80,9 +80,10 @@ template class AbstractDataStore // num_points) to zero virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0; - // With the PQ Data Store PR, we have also changed iterate_to_fixed_point to NOT take the query - // from the scratch object. Therefore every data store has to implement preprocess_query which - // at the least will be to copy the query into the scratch object. So making this pure virtual. + // With the PQ Data Store PR, we have also changed iterate_to_fixed_point to + // NOT take the query from the scratch object. Therefore every data store has + // to implement preprocess_query which at the least will be to copy the query + // into the scratch object. So making this pure virtual. virtual void preprocess_query(const data_t *aligned_query, AbstractScratch *query_scratch = nullptr) const = 0; // distance functions. @@ -99,9 +100,9 @@ template class AbstractDataStore // in the dataset virtual location_t calculate_medoid() const = 0; - // REFACTOR PQ TODO: Each data store knows about its distance function, so this is - // redundant. However, we don't have an OptmizedDataStore yet, and to preserve code - // compability, we are exposing this function. + // REFACTOR PQ TODO: Each data store knows about its distance function, so + // this is redundant. However, we don't have an OptmizedDataStore yet, and to + // preserve code compability, we are exposing this function. 
virtual Distance *get_dist_fn() const = 0; // search helpers diff --git a/include/abstract_filter_store.h b/include/abstract_filter_store.h new file mode 100644 index 000000000..858c6e283 --- /dev/null +++ b/include/abstract_filter_store.h @@ -0,0 +1,25 @@ +#pragma once +#include "types.h" +#include "windows_customizations.h" +#include + +namespace diskann +{ +template class AbstractFilterStore +{ + public: + DISKANN_DLLEXPORT virtual bool has_filter_support() const = 0; + + DISKANN_DLLEXPORT virtual bool point_has_label(location_t point_id, const LabelT label_id) const = 0; + + // Returns true if the index is filter-enabled and all files were loaded + // correctly. false otherwise. Note that "false" can mean that the index + // does not have filter support, or that some index files do not exist, or + // that they exist and could not be opened. + DISKANN_DLLEXPORT virtual bool load(const std::string &disk_index_file) = 0; + + DISKANN_DLLEXPORT virtual void generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads) = 0; +}; + +} // namespace diskann diff --git a/include/abstract_graph_store.h b/include/abstract_graph_store.h index 4d6906ca4..961a4f994 100644 --- a/include/abstract_graph_store.h +++ b/include/abstract_graph_store.h @@ -3,9 +3,9 @@ #pragma once +#include "types.h" #include #include -#include "types.h" namespace diskann { diff --git a/include/abstract_index.h b/include/abstract_index.h index 059866f7c..9f0d402ef 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -1,10 +1,10 @@ #pragma once #include "distance.h" +#include "index_build_params.h" +#include "index_config.h" #include "parameters.h" -#include "utils.h" #include "types.h" -#include "index_config.h" -#include "index_build_params.h" +#include "utils.h" #include namespace diskann @@ -32,8 +32,9 @@ struct consolidation_report } }; -/* A templated independent class for intercation with Index. 
Uses Type Erasure to add virtual implemetation of methods -that can take any type(using std::any) and Provides a clean API that can be inherited by different type of Index. +/* A templated independent class for intercation with Index. Uses Type Erasure +to add virtual implemetation of methods that can take any type(using std::any) +and Provides a clean API that can be inherited by different type of Index. */ class AbstractIndex { diff --git a/include/aligned_file_reader.h b/include/aligned_file_reader.h index f39d5da39..447b34609 100644 --- a/include/aligned_file_reader.h +++ b/include/aligned_file_reader.h @@ -5,8 +5,8 @@ #define MAX_IO_DEPTH 128 -#include #include +#include #ifndef _WINDOWS #include @@ -63,12 +63,12 @@ struct IOContext #endif -#include +#include "tsl/robin_map.h" +#include "utils.h" #include +#include #include #include -#include "tsl/robin_map.h" -#include "utils.h" // NOTE :: all 3 fields must be 512-aligned struct AlignedRead diff --git a/include/ann_exception.h b/include/ann_exception.h index 6b81373c1..a9b940573 100644 --- a/include/ann_exception.h +++ b/include/ann_exception.h @@ -2,10 +2,10 @@ // Licensed under the MIT license. #pragma once -#include +#include "windows_customizations.h" #include +#include #include -#include "windows_customizations.h" #ifndef _WINDOWS #define __FUNCSIG__ __PRETTY_FUNCTION__ diff --git a/include/any_wrappers.h b/include/any_wrappers.h index da9005cfb..f35ac947c 100644 --- a/include/any_wrappers.h +++ b/include/any_wrappers.h @@ -3,11 +3,11 @@ #pragma once -#include +#include "tsl/robin_set.h" +#include #include +#include #include -#include -#include "tsl/robin_set.h" namespace AnyWrapper { diff --git a/include/cached_io.h b/include/cached_io.h index daef2f2f7..dabe448dc 100644 --- a/include/cached_io.h +++ b/include/cached_io.h @@ -2,13 +2,14 @@ // Licensed under the MIT license. 
#pragma once +#include #include #include #include #include -#include "logger.h" #include "ann_exception.h" +#include "logger.h" // sequential cached reads class cached_ifstream diff --git a/include/common_includes.h b/include/common_includes.h index e1a51bdec..c69098fbc 100644 --- a/include/common_includes.h +++ b/include/common_includes.h @@ -14,14 +14,15 @@ #include #include #include -#include #include +#include #include #include #include #include #include -#include #include +#include #include #include +#include diff --git a/include/cosine_similarity.h b/include/cosine_similarity.h index dc51f6c0a..af62eb53b 100644 --- a/include/cosine_similarity.h +++ b/include/cosine_similarity.h @@ -3,16 +3,16 @@ #pragma once -#include -#include -#include +#include #include #include #include -#include +#include #include -#include +#include #include +#include +#include #include "simd_utils.h" diff --git a/include/disk_utils.h b/include/disk_utils.h index 08f046dcd..df654c744 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -3,10 +3,10 @@ #pragma once #include -#include #include #include #include +#include #include #include #include @@ -96,9 +96,12 @@ DISKANN_DLLEXPORT int build_disk_index( diskann::Metric _compareMetric, bool use_opq = false, const std::string &codebook_prefix = "", // default is empty for no codebook pass in bool use_filters = false, - const std::string &label_file = std::string(""), // default is empty string for no label_file - const std::string &universal_label = "", const uint32_t filter_threshold = 0, - const uint32_t Lf = 0); // default is empty string for no universal label + const std::string &label_file = + std::string(""), // default is empty string for no label_file + const std::string &universal_label = "", + const uint32_t filter_threshold = 0, + const uint32_t Lf = 0, + const uint32_t filter_bf_threshold = 0); // default is empty string for no universal label template DISKANN_DLLEXPORT void create_disk_layout(const std::string 
base_file, const std::string mem_index_file, diff --git a/include/distance.h b/include/distance.h index f3b1de25a..7a3ec8b26 100644 --- a/include/distance.h +++ b/include/distance.h @@ -1,5 +1,6 @@ #pragma once #include "windows_customizations.h" +#include #include namespace diskann diff --git a/include/filter_brute_force_index.h b/include/filter_brute_force_index.h new file mode 100644 index 000000000..09db7bf50 --- /dev/null +++ b/include/filter_brute_force_index.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. +#pragma once +#include "common_includes.h" +#include "windows_customizations.h" +#include "filter_utils.h" + +namespace diskann { + + template + class FilterBruteForceIndex { + public : + DISKANN_DLLEXPORT FilterBruteForceIndex(const std::string& disk_index_file); + DISKANN_DLLEXPORT bool brute_force_index_available() const; + DISKANN_DLLEXPORT bool brute_forceable_filter(const std::string& filter) const; + DISKANN_DLLEXPORT int load(); + + private : + diskann::inverted_index_t _bf_filter_index; + bool _is_loaded; + std::string _disk_index_file; + }; +} \ No newline at end of file diff --git a/include/filter_utils.h b/include/filter_utils.h index 55f7aed28..472b8c49d 100644 --- a/include/filter_utils.h +++ b/include/filter_utils.h @@ -3,19 +3,20 @@ #pragma once #include -#include #include #include #include +#include #include #include #include #include #include -#include #include +#include #include #include +#include #ifdef __APPLE__ #else #include @@ -48,8 +49,15 @@ typedef std::tuple, tsl::robin_map parse_label_file_return_values; typedef std::tuple>, uint64_t> load_label_index_return_values; -namespace diskann -{ + +namespace diskann { + + //CONSTANTS + DISKANN_DLLEXPORT extern const char* NO_LABEL_FOR_POINT; + DISKANN_DLLEXPORT extern const char FILTERS_LABEL_DELIMITER; + typedef std::map> inverted_index_t; + + template DISKANN_DLLEXPORT void generate_label_indices(path 
input_data_path, path final_index_path_prefix, label_set all_labels, unsigned R, unsigned L, float alpha, unsigned num_threads); @@ -68,6 +76,13 @@ DISKANN_DLLEXPORT tsl::robin_map> generate_la path input_data_path, tsl::robin_map labels_to_number_of_points, std::vector point_ids_to_labels, label_set all_labels); +template +DISKANN_DLLEXPORT void separate_brute_forceable_points( + const std::string& base_file, const std::string& label_file, + const location_t filter_bf_threshold, + const std::string& new_lbl_file, + const std::string& bf_data_file); + /* * For each label, generates a file containing all vectors that have said label. * Also copies data from original bin file to new dimension-aligned file. diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 0a0a617da..ad5dc8d7a 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -2,8 +2,8 @@ // Licensed under the MIT license. #pragma once -#include #include +#include #include "tsl/robin_map.h" #include "tsl/robin_set.h" @@ -12,10 +12,10 @@ #include "abstract_data_store.h" +#include "aligned_file_reader.h" #include "distance.h" #include "natural_number_map.h" #include "natural_number_set.h" -#include "aligned_file_reader.h" namespace diskann { diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h new file mode 100644 index 000000000..4915f37ee --- /dev/null +++ b/include/in_mem_filter_store.h @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +#pragma once + +#include "abstract_filter_store.h" +#include "ann_exception.h" +#include "logger.h" +#include "tsl/robin_map.h" +#include "tsl/robin_set.h" +#include "windows_customizations.h" +#include +#include +#include + +namespace diskann +{ +template class InMemFilterStore : public AbstractFilterStore +{ + public: + // Do nothing constructor because all the work is done in load() + DISKANN_DLLEXPORT InMemFilterStore() + { + } + + /// + /// Destructor + /// + DISKANN_DLLEXPORT virtual ~InMemFilterStore(); + + // No copy, no assignment. + DISKANN_DLLEXPORT InMemFilterStore &operator=(const InMemFilterStore &v) = delete; + DISKANN_DLLEXPORT + InMemFilterStore(const InMemFilterStore &v) = delete; + + DISKANN_DLLEXPORT virtual bool has_filter_support() const; + + DISKANN_DLLEXPORT virtual const std::unordered_map> &get_label_to_medoids() const; + + DISKANN_DLLEXPORT virtual const std::vector &get_medoids_of_label(const LabelT label); + + DISKANN_DLLEXPORT virtual void set_universal_label(const LabelT univ_label); + + DISKANN_DLLEXPORT inline bool point_has_label(location_t point_id, const LabelT label_id) const + { + uint32_t start_vec = _pts_to_label_offsets[point_id]; + uint32_t num_lbls = _pts_to_label_counts[point_id]; + bool ret_val = false; + for (uint32_t i = 0; i < num_lbls; i++) + { + if (_pts_to_labels[start_vec + i] == label_id) + { + ret_val = true; + break; + } + } + return ret_val; + } + + DISKANN_DLLEXPORT inline bool is_dummy_point(location_t id) const + { + return _dummy_pts.find(id) != _dummy_pts.end(); + } + + DISKANN_DLLEXPORT inline location_t get_real_point_for_dummy(location_t dummy_id) + { + if (is_dummy_point(dummy_id)) + { + return _dummy_to_real_map[dummy_id]; + } + else + { + return dummy_id; // it is a real point. 
+ } + } + + DISKANN_DLLEXPORT inline bool point_has_label_or_universal_label(location_t id, const LabelT filter_label) const + { + return point_has_label(id, filter_label) || + (_use_universal_label && point_has_label(id, _universal_filter_label)); + } + + DISKANN_DLLEXPORT inline LabelT get_converted_label(const std::string &filter_label) + { + if (_label_map.find(filter_label) != _label_map.end()) + { + return _label_map[filter_label]; + } + if (_use_universal_label) + { + return _universal_filter_label; + } + std::stringstream stream; + stream << "Unable to find label in the Label Map"; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + // Returns true if the index is filter-enabled and all files were loaded + // correctly. false otherwise. Note that "false" can mean that the index + // does not have filter support, or that some index files do not exist, or + // that they exist and could not be opened. 
+ DISKANN_DLLEXPORT bool load(const std::string &disk_index_file); + + DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads); + + private: + // Load functions for search START + void load_label_file(const std::string_view &file_content); + void load_label_map(std::basic_istream &map_reader); + void load_labels_to_medoids(std::basic_istream &reader); + void load_dummy_map(std::basic_istream &dummy_map_stream); + void parse_universal_label(const std::string_view &content); + void get_label_file_metadata(const std::string_view &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); + + bool load_file_and_parse(const std::string &filename, + void (InMemFilterStore::*parse_fn)(const std::string_view &content)); + bool parse_stream(const std::string &filename, + void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)); + + void reset_stream_for_reading(std::basic_istream &infile); + // Load functions for search END + + location_t _num_points = 0; + location_t *_pts_to_label_offsets = nullptr; + location_t *_pts_to_label_counts = nullptr; + LabelT *_pts_to_labels = nullptr; + bool _use_universal_label = false; + LabelT _universal_filter_label; + tsl::robin_set _dummy_pts; + tsl::robin_set _has_dummy_pts; + tsl::robin_map _dummy_to_real_map; + tsl::robin_map> _real_to_dummy_map; + std::unordered_map _label_map; + std::unordered_map> _filter_to_medoid_ids; + bool _is_valid = false; +}; + +} // namespace diskann diff --git a/include/index.h b/include/index.h index b9bf4f384..19694c785 100644 --- a/include/index.h +++ b/include/index.h @@ -9,21 +9,21 @@ #include "aligned_file_reader.h" #endif +#include "abstract_index.h" #include "distance.h" +#include "in_mem_data_store.h" +#include "in_mem_graph_store.h" #include "locking.h" #include "natural_number_map.h" #include "natural_number_set.h" #include "neighbor.h" #include "parameters.h" +#include "scratch.h" #include "utils.h" #include 
"windows_customizations.h" -#include "scratch.h" -#include "in_mem_data_store.h" -#include "in_mem_graph_store.h" -#include "abstract_index.h" -#include "quantized_distance.h" #include "pq_data_store.h" +#include "quantized_distance.h" #define OVERHEAD_FACTOR 1.1 #define EXPAND_IF_FULL 0 @@ -55,18 +55,19 @@ template clas public: // Constructor for Bulk operations and for creating the index object solely // for loading a prexisting index. - DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::shared_ptr> data_store, - std::unique_ptr graph_store, - std::shared_ptr> pq_data_store = nullptr); + DISKANN_DLLEXPORT + Index(const IndexConfig &index_config, std::shared_ptr> data_store, + std::unique_ptr graph_store, + std::shared_ptr> pq_data_store = nullptr); // Constructor for incremental index - DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points, - const std::shared_ptr index_parameters, - const std::shared_ptr index_search_params, - const size_t num_frozen_pts = 0, const bool dynamic_index = false, - const bool enable_tags = false, const bool concurrent_consolidate = false, - const bool pq_dist_build = false, const size_t num_pq_chunks = 0, - const bool use_opq = false, const bool filtered_index = false); + DISKANN_DLLEXPORT + Index(Metric m, const size_t dim, const size_t max_points, + const std::shared_ptr index_parameters, + const std::shared_ptr index_search_params, const size_t num_frozen_pts = 0, + const bool dynamic_index = false, const bool enable_tags = false, const bool concurrent_consolidate = false, + const bool pq_dist_build = false, const size_t num_pq_chunks = 0, const bool use_opq = false, + const bool filtered_index = false); DISKANN_DLLEXPORT ~Index(); @@ -376,18 +377,19 @@ template clas // Filter Support bool _filtered_index = false; - // Location to label is only updated during insert_point(), all other reads are protected by - // default as a location can only be released at end of consolidate deletes + // Location 
to label is only updated during insert_point(), all other reads + // are protected by default as a location can only be released at end of + // consolidate deletes std::vector> _location_to_labels; tsl::robin_set _labels; std::string _labels_file; std::unordered_map _label_to_start_id; std::unordered_map _medoid_counts; - bool _use_universal_label = false; - LabelT _universal_label = 0; - uint32_t _filterIndexingQueueSize; - std::unordered_map _label_map; + bool _use_universal_label = false; + LabelT _universal_label = 0; + uint32_t _filter_indexing_queue_size; + std::unordered_map _label_map; // Indexing parameters uint32_t _indexingQueueSize; @@ -436,7 +438,8 @@ template clas std::shared_timed_mutex // Ensure only one consolidate or compact_data is _consolidate_lock; // ever active std::shared_timed_mutex // RW lock for _tag_to_location, - _tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, _label_to_start_id + _tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, + // _label_to_start_id std::shared_timed_mutex // RW Lock on _delete_set and _data_compacted _delete_lock; // variable diff --git a/include/index_build_params.h b/include/index_build_params.h index d4f454830..38434e204 100644 --- a/include/index_build_params.h +++ b/include/index_build_params.h @@ -1,5 +1,6 @@ #pragma once +#include "ann_exception.h" #include "common_includes.h" #include "parameters.h" @@ -32,7 +33,7 @@ class IndexFilterParamsBuilder IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) { if (save_path_prefix.empty() || save_path_prefix == "") - throw ANNException("Error: save_path_prefix can't be empty", -1); + throw diskann::ANNException("Error: save_path_prefix can't be empty", -1); this->_save_path_prefix = save_path_prefix; return *this; } diff --git a/include/index_config.h b/include/index_config.h index a8e64d08c..d1709cc1e 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -1,7 +1,13 @@ +// Copyright (c) 
Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + #pragma once +#include "ann_exception.h" #include "common_includes.h" +#include "logger.h" #include "parameters.h" +#include namespace diskann { @@ -209,13 +215,17 @@ class IndexConfigBuilder if (_dynamic_index) { if (_index_search_params != nullptr && _index_search_params->initial_search_list_size == 0) - throw ANNException("Error: please pass initial_search_list_size for building dynamic index.", -1); + throw ANNException("Error: please pass initial_search_list_size for " + "building dynamic index.", + -1); } // sanity check if (_dynamic_index && _num_frozen_pts == 0) { - diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting it to 1 for safety." << std::endl; + diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting " + "it to 1 for safety." + << std::endl; _num_frozen_pts = 1; } diff --git a/include/index_factory.h b/include/index_factory.h index 76fb0b978..a41c1f50f 100644 --- a/include/index_factory.h +++ b/include/index_factory.h @@ -1,8 +1,8 @@ #pragma once -#include "index.h" #include "abstract_graph_store.h" #include "in_mem_graph_store.h" +#include "index.h" #include "pq_data_store.h" namespace diskann @@ -20,9 +20,9 @@ class IndexFactory DISKANN_DLLEXPORT static std::shared_ptr> construct_datastore(DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); - // For now PQDataStore incorporates within itself all variants of quantization that we support. In the - // future it may be necessary to introduce an AbstractPQDataStore class to spearate various quantization - // flavours. + // For now PQDataStore incorporates within itself all variants of quantization + // that we support. In the future it may be necessary to introduce an + // AbstractPQDataStore class to spearate various quantization flavours. 
template DISKANN_DLLEXPORT static std::shared_ptr> construct_pq_datastore(DataStoreStrategy strategy, size_t num_points, size_t dimension, diff --git a/include/logger.h b/include/logger.h index 0b17807db..f1c6ee7f3 100644 --- a/include/logger.h +++ b/include/logger.h @@ -2,9 +2,9 @@ // Licensed under the MIT license. #pragma once +#include "windows_customizations.h" #include #include -#include "windows_customizations.h" #ifdef EXEC_ENV_OLS #ifndef ENABLE_CUSTOM_LOGGER diff --git a/include/logger_impl.h b/include/logger_impl.h index 03c65e0ce..d2dfaf573 100644 --- a/include/logger_impl.h +++ b/include/logger_impl.h @@ -3,8 +3,8 @@ #pragma once -#include #include +#include #include "ann_exception.h" #include "logger.h" diff --git a/include/neighbor.h b/include/neighbor.h index d7c0c25ed..61a6932c1 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -3,10 +3,10 @@ #pragma once +#include "utils.h" #include #include #include -#include "utils.h" namespace diskann { diff --git a/include/parameters.h b/include/parameters.h index 0206814bd..50e7e4a1a 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -6,8 +6,8 @@ #include #include -#include "omp.h" #include "defaults.h" +#include "omp.h" namespace diskann { diff --git a/include/pq.h b/include/pq.h index 3e6119f22..db9226d8b 100644 --- a/include/pq.h +++ b/include/pq.h @@ -3,8 +3,8 @@ #pragma once -#include "utils.h" #include "pq_common.h" +#include "utils.h" namespace diskann { diff --git a/include/pq_common.h b/include/pq_common.h index c6a3a5739..d7a4b60f4 100644 --- a/include/pq_common.h +++ b/include/pq_common.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #define NUM_PQ_BITS 8 #define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS) diff --git a/include/pq_data_store.h b/include/pq_data_store.h index 7c0cb5fe0..4e223e785 100644 --- a/include/pq_data_store.h +++ b/include/pq_data_store.h @@ -1,14 +1,15 @@ #pragma once -#include +#include "abstract_data_store.h" #include "distance.h" -#include 
"quantized_distance.h" #include "pq.h" -#include "abstract_data_store.h" +#include "quantized_distance.h" +#include namespace diskann { -// REFACTOR TODO: By default, the PQDataStore is an in-memory datastore because both Vamana and -// DiskANN treat it the same way. But with DiskPQ, that may need to change. +// REFACTOR TODO: By default, the PQDataStore is an in-memory datastore because +// both Vamana and DiskANN treat it the same way. But with DiskPQ, that may need +// to change. template class PQDataStore : public AbstractDataStore { @@ -30,8 +31,8 @@ template class PQDataStore : public AbstractDataStore // vectors file. virtual size_t save(const std::string &file_prefix, const location_t num_points) override; - // Since base class function is pure virtual, we need to declare it here, even though alignent concept is not needed - // for Quantized data stores. + // Since base class function is pure virtual, we need to declare it here, even + // though alignent concept is not needed for Quantized data stores. 
virtual size_t get_aligned_dim() const override; // Populate quantized data from unaligned data using PQ functionality diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index ba5258e18..9e43debfc 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -10,11 +10,13 @@ #include "parameters.h" #include "percentile_stats.h" #include "pq.h" -#include "utils.h" -#include "windows_customizations.h" #include "scratch.h" #include "tsl/robin_map.h" #include "tsl/robin_set.h" +#include "utils.h" +#include "windows_customizations.h" + +#include "in_mem_filter_store.h" #define FULL_PRECISION_REORDER_MULTIPLIER 3 @@ -81,8 +83,6 @@ template class PQFlashIndex const uint32_t io_limit, const bool use_reorder_data = false, QueryStats *stats = nullptr); - DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); - DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search, const uint64_t max_l_search, std::vector &indices, std::vector &distances, const uint64_t min_beam_width, @@ -90,16 +90,20 @@ template class PQFlashIndex DISKANN_DLLEXPORT uint64_t get_data_dim(); + DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); + std::shared_ptr &reader; DISKANN_DLLEXPORT diskann::Metric get_metric(); // // node_ids: input list of node_ids to be read - // coord_buffers: pointers to pre-allocated buffers that coords need to copied to. If null, dont copy. - // nbr_buffers: pre-allocated buffers to copy neighbors into + // coord_buffers: pointers to pre-allocated buffers that coords need to copied + // to. If null, dont copy. 
nbr_buffers: pre-allocated buffers to copy + // neighbors into // - // returns a vector of bool one for each node_id: true if read is success, else false + // returns a vector of bool one for each node_id: true if read is success, + // else false // DISKANN_DLLEXPORT std::vector read_nodes(const std::vector &node_ids, std::vector &coord_buffers, @@ -112,18 +116,7 @@ template class PQFlashIndex DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096); - DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); - private: - DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); - std::unordered_map load_label_map(std::basic_istream &infile); - DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, - uint32_t &num_total_labels); - DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads); - void reset_stream_for_reading(std::basic_istream &infile); - // sector # on disk where node_id is present with in the graph part DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id); @@ -141,8 +134,8 @@ template class PQFlashIndex // offset in sector: [(i % nnodes_per_sector) * max_node_len] // // index info for multi-sector nodes - // nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, SECTOR_LEN)] - // offset in sector: [0] + // nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, + // SECTOR_LEN)] offset in sector: [0] // // Common info // coords start at ofsset @@ -221,18 +214,10 @@ template class PQFlashIndex bool _reorder_data_exists = false; uint64_t _reoreder_data_offset = 0; - // filter support - uint32_t *_pts_to_label_offsets = nullptr; - uint32_t *_pts_to_label_counts = nullptr; - LabelT *_pts_to_labels = nullptr; - 
std::unordered_map> _filter_to_medoid_ids; - bool _use_universal_label = false; - LabelT _universal_filter_label; - tsl::robin_set _dummy_pts; - tsl::robin_set _has_dummy_pts; - tsl::robin_map _dummy_to_real_map; - tsl::robin_map> _real_to_dummy_map; - std::unordered_map _label_map; + // Moved filter-specific data structures to in_mem_filter_store. + // TODO: Make this a unique pointer + bool _filter_index = false; + std::unique_ptr> _filter_store; #ifdef EXEC_ENV_OLS // Set to a larger value than the actual header to accommodate diff --git a/include/pq_scratch.h b/include/pq_scratch.h index 95f1b1395..6b52463eb 100644 --- a/include/pq_scratch.h +++ b/include/pq_scratch.h @@ -1,7 +1,7 @@ #pragma once -#include #include "pq_common.h" #include "utils.h" +#include namespace diskann { diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp index 2be60595b..965a7dc5b 100644 --- a/include/program_options_utils.hpp +++ b/include/program_options_utils.hpp @@ -43,6 +43,11 @@ const char *FILTERS_FILE_DESCRIPTION = const char *LABEL_TYPE_DESCRIPTION = "Storage type of Labels {uint/uint32, ushort/uint16}, default value is uint which will consume memory 4 bytes per " "filter. 'uint' is an alias for 'uint32' and 'ushort' is an alias for 'uint16'."; +const char* FILTER_THRESHOLD_DESCRIPTION = "Threshold to break up the existing nodes to generate new graph " +"internally where each node has a maximum F labels."; +const char* FILTER_BRUTEFORCE_THRESHOLD_DESCRIPTION = "Use brute force for searching with a filter if it occurs" +" fewer than this many times in the dataset."; + const char *GROUND_TRUTH_FILE_DESCRIPTION = "ground truth file for the queryset"; // what's the format, what's the requirements? does it need to include an // entry for every item or just a small subset? 
I have so many questions about diff --git a/include/quantized_distance.h b/include/quantized_distance.h index cc4aea929..44798ac96 100644 --- a/include/quantized_distance.h +++ b/include/quantized_distance.h @@ -1,8 +1,8 @@ #pragma once +#include "abstract_scratch.h" #include #include #include -#include "abstract_scratch.h" namespace diskann { @@ -48,9 +48,10 @@ template class QuantizedDistance virtual void preprocessed_distance(PQScratch &pq_scratch, const uint32_t n_ids, std::vector &dists_out) = 0; - // Currently this function is required for DiskPQ. However, it too can be subsumed - // under preprocessed_distance if we add the appropriate scratch variables to - // PQScratch and initialize them in pq_flash_index.cpp::disk_iterate_to_fixed_point() + // Currently this function is required for DiskPQ. However, it too can be + // subsumed under preprocessed_distance if we add the appropriate scratch + // variables to PQScratch and initialize them in + // pq_flash_index.cpp::disk_iterate_to_fixed_point() virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) = 0; }; } // namespace diskann diff --git a/include/restapi/search_wrapper.h b/include/restapi/search_wrapper.h index ebd067d8a..d41b2b7cd 100644 --- a/include/restapi/search_wrapper.h +++ b/include/restapi/search_wrapper.h @@ -3,9 +3,9 @@ #pragma once +#include #include #include -#include #include #include diff --git a/include/restapi/server.h b/include/restapi/server.h index 1d75847a2..ddb19d17a 100644 --- a/include/restapi/server.h +++ b/include/restapi/server.h @@ -3,8 +3,8 @@ #pragma once -#include #include +#include namespace diskann { diff --git a/include/scratch.h b/include/scratch.h index 2f43e3365..79bb027de 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -7,15 +7,15 @@ #include "boost_dynamic_bitset_fwd.h" // #include "boost/dynamic_bitset.hpp" -#include "tsl/robin_set.h" #include "tsl/robin_map.h" +#include "tsl/robin_set.h" #include "tsl/sparse_map.h" -#include 
"aligned_file_reader.h" #include "abstract_scratch.h" -#include "neighbor.h" -#include "defaults.h" +#include "aligned_file_reader.h" #include "concurrent_queue.h" +#include "defaults.h" +#include "neighbor.h" namespace diskann { diff --git a/include/simd_utils.h b/include/simd_utils.h index 4b0736998..da59c0cde 100644 --- a/include/simd_utils.h +++ b/include/simd_utils.h @@ -2,9 +2,9 @@ #ifdef _WINDOWS #include +#include #include #include -#include #else #include #endif diff --git a/include/types.h b/include/types.h index 953d59a5f..58d8d40a4 100644 --- a/include/types.h +++ b/include/types.h @@ -3,10 +3,10 @@ #pragma once -#include -#include -#include #include "any_wrappers.h" +#include +#include +#include namespace diskann { diff --git a/include/utils.h b/include/utils.h index d3af5c3a9..39589e5eb 100644 --- a/include/utils.h +++ b/include/utils.h @@ -20,14 +20,14 @@ typedef HANDLE FileHandle; typedef int FileHandle; #endif +#include "ann_exception.h" +#include "cached_io.h" #include "distance.h" #include "logger.h" -#include "cached_io.h" -#include "ann_exception.h" -#include "windows_customizations.h" +#include "tag_uint128.h" #include "tsl/robin_set.h" #include "types.h" -#include "tag_uint128.h" +#include "windows_customizations.h" #include #ifdef EXEC_ENV_OLS @@ -177,6 +177,33 @@ inline int delete_file(const std::string &fileName) } } +inline std::string trim(const std::string& str) { + // Find the first non-whitespace character + size_t start = 0; + while (start < str.size() && std::isspace(str[start])) { + ++start; + } + + // Find the last non-whitespace character + size_t end = str.size(); + while (end > start && std::isspace(str[end - 1])) { + --end; + } + + // Return the substring from start to end + return str.substr(start, end - start); +} + +//VERY INEFFICIENT SPLIT_FUNCTION. USE IT AT YOUR OWN RISK. 
+inline void split_string(const std::string& str, const char sep, std::vector& split_strings) { + std::string token; + std::istringstream iss(str); + while (getline(iss, token, sep)) { + token = trim(token); + split_strings.push_back(token); + } +} + // generates formatted_label and _labels_map file. inline void convert_labels_string_to_int(const std::string &inFileName, const std::string &outFileName, const std::string &mapFileName, const std::string &unv_label) @@ -1188,8 +1215,8 @@ template <> inline const char *diskann_type_to_name() } #ifdef _WINDOWS -#include #include +#include extern bool AvxSupportedCPU; extern bool Avx2SupportedCPU; diff --git a/include/windows_aligned_file_reader.h b/include/windows_aligned_file_reader.h index 0d9a3173c..e3a898b9a 100644 --- a/include/windows_aligned_file_reader.h +++ b/include/windows_aligned_file_reader.h @@ -9,13 +9,13 @@ #include #include -#include -#include -#include #include "aligned_file_reader.h" #include "tsl/robin_map.h" #include "utils.h" #include "windows_customizations.h" +#include +#include +#include class WindowsAlignedFileReader : public AlignedFileReader { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cbca26440..80b11754f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -11,9 +11,10 @@ else() set(CPP_SOURCES abstract_data_store.cpp ann_exception.cpp disk_utils.cpp distance.cpp index.cpp in_mem_graph_store.cpp in_mem_data_store.cpp linux_aligned_file_reader.cpp math_utils.cpp natural_number_map.cpp - in_mem_data_store.cpp in_mem_graph_store.cpp + in_mem_data_store.cpp in_mem_graph_store.cpp in_mem_filter_store.cpp natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp - pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp index_factory.cpp abstract_index.cpp pq_l2_distance.cpp pq_data_store.cpp) + pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp + index_factory.cpp abstract_index.cpp pq_l2_distance.cpp pq_data_store.cpp) if (RESTAPI) 
list(APPEND CPP_SOURCES restapi/search_wrapper.cpp restapi/server.cpp) endif() diff --git a/src/abstract_data_store.cpp b/src/abstract_data_store.cpp index 0cff0152e..79efaca45 100644 --- a/src/abstract_data_store.cpp +++ b/src/abstract_data_store.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "abstract_data_store.h" +#include namespace diskann { diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 92665825f..c3c257ba2 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -1,6 +1,6 @@ +#include "abstract_index.h" #include "common_includes.h" #include "windows_customizations.h" -#include "abstract_index.h" namespace diskann { diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 22f1e98fd..5a1fb5e6b 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -7,17 +7,19 @@ #include "gperftools/malloc_extension.h" #endif -#include "logger.h" -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" #include "index.h" +#include "logger.h" #include "mkl.h" #include "omp.h" -#include "percentile_stats.h" #include "partition.h" +#include "percentile_stats.h" #include "pq_flash_index.h" #include "timer.h" #include "tsl/robin_set.h" +#include "utils.h" +#include "filter_utils.h" namespace diskann { @@ -1097,43 +1099,48 @@ void create_disk_layout(const std::string base_file, const std::string mem_index diskann::cout << "Output disk index file written to " << output_file << std::endl; } + + template -int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, - const uint32_t Lf) -{ - std::stringstream parser; - parser << std::string(indexBuildParameters); - std::string cur_param; - std::vector 
param_list; - while (parser >> cur_param) - { - param_list.push_back(cur_param); - } - if (param_list.size() < 5 || param_list.size() > 9) - { - diskann::cout << "Correct usage of parameters is R (max degree)\n" - "L (indexing list size, better if >= R)\n" - "B (RAM limit of final index in GB)\n" - "M (memory limit while indexing)\n" - "T (number of threads for indexing)\n" - "B' (PQ bytes for disk index: optional parameter for " - "very large dimensional data)\n" - "reorder (set true to include full precision in data file" - ": optional paramter, use only when using disk PQ\n" - "build_PQ_byte (number of PQ bytes for inde build; set 0 to use " - "full precision vectors)\n" - "QD Quantized Dimension to overwrite the derived dim from B " - << std::endl; - return -1; - } +int build_disk_index(const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, + const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf, + const uint32_t filter_bf_threshold) { + std::stringstream parser; + parser << std::string(indexBuildParameters); + std::string cur_param; + std::vector param_list; + while (parser >> cur_param) { + param_list.push_back(cur_param); + } + if (param_list.size() < 5 || param_list.size() > 10) { + diskann::cout + << "Correct usage of parameters is R (max degree)\n" + "L (indexing list size, better if >= R)\n" + "B (RAM limit of final index in GB)\n" + "M (memory limit while indexing)\n" + "T (number of threads for indexing)\n" + "B' (PQ bytes for disk index: optional parameter for " + "very large dimensional data)\n" + "reorder (set true to include full precision in data file" + ": optional paramter, use only when using disk PQ\n" + "build_PQ_byte (number of PQ bytes for inde build; set 0 to use " + "full precision vectors)\n" + "QD Quantized Dimension to overwrite the 
derived dim from B " + << std::endl; + return -1; + } if (!std::is_same::value && (compareMetric == diskann::Metric::INNER_PRODUCT || compareMetric == diskann::Metric::COSINE)) { std::stringstream stream; - stream << "Disk-index build currently only supports floating point data for Max " + stream << "Disk-index build currently only supports floating point data " + "for Max " "Inner Product Search/ cosine similarity. " << std::endl; throw diskann::ANNException(stream.str(), -1); @@ -1182,55 +1189,74 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::string medoids_path = disk_index_path + "_medoids.bin"; std::string centroids_path = disk_index_path + "_centroids.bin"; - std::string labels_to_medoids_path = disk_index_path + "_labels_to_medoids.txt"; - std::string mem_labels_file = mem_index_path + "_labels.txt"; - std::string disk_labels_file = disk_index_path + "_labels.txt"; - std::string mem_univ_label_file = mem_index_path + "_universal_label.txt"; - std::string disk_univ_label_file = disk_index_path + "_universal_label.txt"; - std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt"; - std::string dummy_remap_file = disk_index_path + "_dummy_map.txt"; // remap will be used if we break-up points of - // high label-density to create copies + std::string labels_to_medoids_path = + disk_index_path + "_labels_to_medoids.txt"; + std::string mem_labels_file = mem_index_path + "_labels.txt"; + std::string disk_labels_file = disk_index_path + "_labels.txt"; + std::string mem_univ_label_file = mem_index_path + "_universal_label.txt"; + std::string disk_univ_label_file = disk_index_path + "_universal_label.txt"; + std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt"; + std::string dummy_remap_file = + disk_index_path + + "_dummy_map.txt"; // remap will be used if we break-up points of + // high label-density to create copies + std::string bf_data_file = disk_index_path + "_brute_force.txt"; + 
std::string bf_excluded_label_file = disk_index_path + "_non_brute_force.txt"; std::string sample_base_prefix = index_prefix_path + "_sample"; // optional, used if disk index file must store pq data std::string disk_pq_pivots_path = index_prefix_path + "_disk.index_pq_pivots.bin"; // optional, used if disk index must store pq data std::string disk_pq_compressed_vectors_path = index_prefix_path + "_disk.index_pq_compressed.bin"; - std::string prepped_base = - index_prefix_path + - "_prepped_base.bin"; // temp file for storing pre-processed base file for cosine/ mips metrics + std::string prepped_base = index_prefix_path + "_prepped_base.bin"; // temp file for storing pre-processed base file + // for cosine/ mips metrics bool created_temp_file_for_processed_data = false; - // output a new base file which contains extra dimension with sqrt(1 - - // ||x||^2/M^2) for every x, M is max norm of all points. Extra space on - // disk needed! - if (compareMetric == diskann::Metric::INNER_PRODUCT) - { - Timer timer; - std::cout << "Using Inner Product search, so need to pre-process base " - "data into temp file. Please ensure there is additional " - "(n*(d+1)*4) bytes for storing pre-processed base vectors, " - "apart from the interim indices created by DiskANN and the final index." + + //Brute force check. + if (use_filters && filter_bf_threshold > 0) { + assert(label_file != ""); + diskann::separate_brute_forceable_points(data_file_to_use, label_file, + filter_bf_threshold, bf_excluded_label_file, bf_data_file); + labels_file_original = bf_excluded_label_file; + created_temp_file_for_processed_data = true; + } + + // output a new base file which contains extra dimension with sqrt(1 - + // ||x||^2/M^2) for every x, M is max norm of all points. Extra space on + // disk needed! + if (compareMetric == diskann::Metric::INNER_PRODUCT) { + Timer timer; + std::cout << "Using Inner Product search, so need to pre-process base " + "data into temp file. 
Please ensure there is additional " + "(n*(d+1)*4) bytes for storing pre-processed base vectors, " + "apart from the interim indices created by DiskANN and the " + "final index." + << std::endl; + data_file_to_use = prepped_base; + float max_norm_of_base = + diskann::prepare_base_for_inner_products(base_file, prepped_base); + std::string norm_file = disk_index_path + "_max_base_norm.bin"; + diskann::save_bin(norm_file, &max_norm_of_base, 1, 1); + diskann::cout << timer.elapsed_seconds_for_step( + "preprocessing data for inner product") << std::endl; - data_file_to_use = prepped_base; - float max_norm_of_base = diskann::prepare_base_for_inner_products(base_file, prepped_base); - std::string norm_file = disk_index_path + "_max_base_norm.bin"; - diskann::save_bin(norm_file, &max_norm_of_base, 1, 1); - diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for inner product") << std::endl; - created_temp_file_for_processed_data = true; - } - else if (compareMetric == diskann::Metric::COSINE) - { - Timer timer; - std::cout << "Normalizing data for cosine to temporary file, please ensure there is additional " - "(n*d*4) bytes for storing normalized base vectors, " - "apart from the interim indices created by DiskANN and the final index." + created_temp_file_for_processed_data = true; + } else if (compareMetric == diskann::Metric::COSINE) { + Timer timer; + std::cout << "Normalizing data for cosine to temporary file, please ensure " + "there is additional " + "(n*d*4) bytes for storing normalized base vectors, " + "apart from the interim indices created by DiskANN and the " + "final index." 
+ << std::endl; + data_file_to_use = prepped_base; + diskann::normalize_data_file(base_file, prepped_base); + diskann::cout << timer.elapsed_seconds_for_step( + "preprocessing data for cosine") << std::endl; - data_file_to_use = prepped_base; - diskann::normalize_data_file(base_file, prepped_base); - diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for cosine") << std::endl; - created_temp_file_for_processed_data = true; - } + created_temp_file_for_processed_data = true; + } uint32_t R = (uint32_t)atoi(param_list[0].c_str()); uint32_t L = (uint32_t)atoi(param_list[1].c_str()); @@ -1320,7 +1346,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) MallocExtension::instance()->ReleaseFreeMemory(); #endif - // Whether it is cosine or inner product, we still L2 metric due to the pre-processing. + // Whether it is cosine or inner product, we still L2 metric due to the + // pre-processing. 
timer.reset(); diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, @@ -1427,49 +1454,49 @@ template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, float *tuning_sample, uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); -template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, - const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, - const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, - const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, - const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, - const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, - const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf, + const uint32_t filter_bf_threshold); +template DISKANN_DLLEXPORT int build_disk_index( + const char 
*dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf, + const uint32_t filter_bf_threshold); +template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf, + const uint32_t filter_bf_threshold); // LabelT = uint16 -template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, - const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, - const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, - const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, - const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, - const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, - const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char 
*indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf, + const uint32_t filter_bf_threshold); +template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf, + const uint32_t filter_bf_threshold); +template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf, + const uint32_t filter_bf_threshold); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, diff --git a/src/distance.cpp b/src/distance.cpp index c2f88c85b..36228afe7 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -1,11 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
// TODO // CHECK COSINE ON LINUX #ifdef _WINDOWS #include +#include #include #include -#include #else #include #endif @@ -14,10 +16,10 @@ #include #include +#include "ann_exception.h" #include "distance.h" -#include "utils.h" #include "logger.h" -#include "ann_exception.h" +#include "utils.h" namespace diskann { diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 096d1b76e..7fc6522ae 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -4,7 +4,8 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp + ../in_mem_filter_store.cpp ../filter_brute_force_index.cpp) set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") @@ -20,6 +21,7 @@ target_include_directories(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTO target_link_options(${PROJECT_NAME} PRIVATE /DLL /IMPLIB:${DISKANN_DLL_IMPLIB} /LTCG) target_link_libraries(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_LINK_LIBRARIES}) target_link_libraries(${PROJECT_NAME} PRIVATE synchronization.lib) +target_link_libraries(${PROJECT_NAME} PRIVATE ${CROARING_LIBRARY}) if (DISKANN_DLL_TCMALLOC_LINK_OPTIONS) target_link_libraries(${PROJECT_NAME} PUBLIC ${DISKANN_DLL_TCMALLOC_LINK_OPTIONS}) diff --git a/src/filter_brute_force_index.cpp b/src/filter_brute_force_index.cpp new file mode 100644 index 000000000..0ba888433 --- /dev/null +++ 
b/src/filter_brute_force_index.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "filter_brute_force_index.h" + +namespace diskann { + + template + FilterBruteForceIndex::FilterBruteForceIndex(const std::string& disk_index_file) { + _disk_index_file = disk_index_file; + _filter_bf_data_file = _disk_index_file + "_brute_force.txt"; + } + template + bool FilterBruteForceIndex::brute_force_index_available() const {} + + template + bool FilterBruteForceIndex::brute_forceable_filter(const std::string& filter) const {} + + template + int FilterBruteForceIndex::load() { + if (false == file_exists(_filter_bf_data_file)) { + diskann::cerr << "Index does not have brute force support." << std::endl; + return 1; + } + std::ifstream bf_in(_filter_bf_data_file); + if (!bf_in.is_open()) { + std::stringstream ss; + ss << "Could not open " << _filter_bf_data_file << " for reading. " << std::endl; + diskann::cerr << ss.str() << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + std::string line; + std::vector label_and_points; + label_and_points.reserve(2); + std::unordered_set points; + + size_t linenum = 0; + while (getline(bf_in, line)) { + split_string(line, '\t', label_and_points); + if (label_and_points.size() == 2) { + + std::istringstream iss(label_and_points[1]); + std::string pt_str; + while (getline(iss, pt_str, ',')) { + points.insert(strtoul(pt_str)); + } + assert(points.size() > 0); + _bf_filter_index.insert(label_and_points[0], points); + points.clear(); + } else { + std::stringstream ss; + ss << "Error reading brute force data at line: " << line_num + << " found " << label_and_points.size() << " tab separated entries instead of 2" + << std::endl; + diskann::cerr << ss.str(); + throw diskann::ANNException(ss.str(), -1); + } + line_num++; + } + } +} \ No newline at end of file diff --git a/src/filter_utils.cpp b/src/filter_utils.cpp index 09d740e35..e8ed3f042 100644 --- 
a/src/filter_utils.cpp +++ b/src/filter_utils.cpp @@ -8,11 +8,11 @@ #include #include -#include #include "filter_utils.h" #include "index.h" #include "parameters.h" #include "utils.h" +#include namespace diskann { @@ -23,6 +23,9 @@ namespace diskann * Each index is saved under the following path: * final_index_path_prefix + "_" + label */ + + const char* NO_LABEL_FOR_POINT = "<>"; + const char FILTERS_LABEL_DELIMITER = ','; template void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels, uint32_t R, uint32_t L, float alpha, uint32_t num_threads) @@ -266,7 +269,8 @@ parse_label_file_return_values parse_label_file(path label_data_path, std::strin * as either uint16_t or uint32_t * * Returns two objects via std::tuple: - * 1. a vector of vectors of labels, where the outer vector is indexed by point id + * 1. a vector of vectors of labels, where the outer vector is indexed by point + * id * 2. a set of all labels */ template @@ -323,8 +327,140 @@ std::tuple>, tsl::robin_set> parse_forma return std::make_tuple(pts_to_labels, labels); } -template DISKANN_DLLEXPORT std::tuple>, tsl::robin_set> -parse_formatted_label_file(path label_file); +//TODO: This is a test implementation of adding brute force logic while +//building a filtered index. Must be cleaned up later. 
+void get_inv_index(const std::string& label_file, const location_t filter_bf_threshold, inverted_index_t& inv_index) { + std::ifstream label_in(label_file); + if (!label_in.is_open()) { + std::stringstream ss; + ss << "Could not open label file: " << label_file << std::endl; + diskann::cerr << ss.str() << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + std::string line; + location_t line_num = 0; + std::vector line_labels; + while (getline(label_in, line)) { + split_string(line, FILTERS_LABEL_DELIMITER, line_labels); + for (auto& label : line_labels) { + if (inv_index.find(label) == inv_index.end()) { + inv_index.insert(std::pair>( + label, std::unordered_set())); + } + inv_index[label].insert(line_num); + } + line_labels.clear(); + line_num++; + } + + diskann::cout << "Built inverted index for filters. Label count: " << inv_index.size(); + auto num_bf_labels = 0; + for (auto& label_and_points : inv_index) { + if (label_and_points.second.size() < filter_bf_threshold) { + num_bf_labels++; + } + } + diskann::cout << " number of sparse labels: " << num_bf_labels << std::endl; + +} + +void get_labels_of_point(const inverted_index_t& inv_index, location_t point, std::vector& labels, location_t sparse_threshold) { + //Maybe we are over optimizing here, but let's see. + for (auto& label_points : inv_index) { + //if a label is sparse then we don't want to consider it for our new label file + //instead it will go into a brute force file. + if (label_points.second.size() >= sparse_threshold) { + if (label_points.second.find(point) != label_points.second.end()) { + labels.push_back(label_points.first); + } + } + } +} + +void write_new_label_file(const inverted_index_t& inv_index, location_t nrows, const std::string& new_label_file, const location_t sparse_threshold) { + std::ofstream label_out(new_label_file); + if (!label_out.is_open()) { + std::stringstream ss; + ss << "Could not open output label file: " << new_label_file << " for writing." 
<< std::endl; + diskann::cerr << ss.str() << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + std::vector labels_of_point; + labels_of_point.reserve(200); //just assuming, won't affect anything. + location_t num_graph_points = 0; + + for (location_t i = 0; i < (location_t)nrows; i++) { + get_labels_of_point(inv_index, i, labels_of_point, sparse_threshold); + if (labels_of_point.size() == 0) { + label_out << NO_LABEL_FOR_POINT << std::endl; + } else { + num_graph_points++; + for (int i = 0; i < labels_of_point.size() - 1; i++) { + label_out << labels_of_point[i] << ","; + } + label_out << labels_of_point[labels_of_point.size() - 1] << std::endl; + } + labels_of_point.clear(); + } + diskann::cout << "New label file: " << new_label_file << ", num graph points: " << num_graph_points << std::endl; + label_out.close(); +} + +void write_brute_force_data(const inverted_index_t& inv_index, const std::string& bf_data_file, + const location_t sparse_threshold) { + std::ofstream bf_out(bf_data_file); + if (!bf_out.is_open()) { + std::stringstream ss; + ss << "Could not open output brute force data file: " << bf_data_file << " for writing." 
<< std::endl; + diskann::cerr << ss.str() << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + for (auto& label_and_points : inv_index) { + if (label_and_points.second.size() < sparse_threshold) { + bf_out << label_and_points.first << "\t"; + + int count = 0; + for (auto id : label_and_points.second) { + bf_out << id; + if (count < label_and_points.second.size() - 1) { + bf_out << ","; + } else { + bf_out << std::endl; + } + count++; + } + } + } + diskann::cout << "Brute force file: " << bf_data_file << std::endl; + bf_out.close(); +} + +template +void separate_brute_forceable_points( + const std::string& base_file, const std::string& label_file, + const location_t filter_bf_threshold, + const std::string& new_lbl_file, + const std::string& bf_data_file) { + + diskann::cout << "Excluding brute forceable points from the dataset for building the diskann graph" << std::endl; + + std::ifstream data_in(base_file, std::ios::binary); + uint64_t nrows, ncols; + get_bin_metadata_impl(data_in, nrows, ncols); + + inverted_index_t inv_index; + get_inv_index(label_file, filter_bf_threshold, inv_index); + + write_new_label_file(inv_index, (location_t)nrows, new_lbl_file, filter_bf_threshold); + write_brute_force_data(inv_index, bf_data_file, filter_bf_threshold); +} + +template DISKANN_DLLEXPORT + std::tuple>, tsl::robin_set> + parse_formatted_label_file(path label_file); template DISKANN_DLLEXPORT std::tuple>, tsl::robin_set> parse_formatted_label_file(path label_file); @@ -352,4 +488,20 @@ generate_label_specific_vector_files_compat(path input_data_path, tsl::robin_map labels_to_number_of_points, std::vector point_ids_to_labels, label_set all_labels); +template DISKANN_DLLEXPORT void separate_brute_forceable_points( + const std::string& base_file, const std::string& label_file, + const location_t filter_bf_threshold, + const std::string& new_lbl_file, + const std::string& bf_data_file); +template DISKANN_DLLEXPORT void separate_brute_forceable_points( + const 
std::string& base_file, const std::string& label_file, + const location_t filter_bf_threshold, + const std::string& new_lbl_file, + const std::string& bf_data_file); +template DISKANN_DLLEXPORT void separate_brute_forceable_points( + const std::string& base_file, const std::string& label_file, + const location_t filter_bf_threshold, + const std::string& new_lbl_file, + const std::string& bf_data_file); + } // namespace diskann \ No newline at end of file diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index cc7acf615..46ddfc92b 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include "abstract_scratch.h" #include "in_mem_data_store.h" +#include "abstract_scratch.h" +#include #include "utils.h" diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp new file mode 100644 index 000000000..c94dccd9c --- /dev/null +++ b/src/in_mem_filter_store.cpp @@ -0,0 +1,416 @@ +#include "in_mem_filter_store.h" +#include "ann_exception.h" +#include "tsl/robin_map.h" +#include "tsl/robin_set.h" +#include "utils.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace diskann +{ +// TODO: Move to utils.h +DISKANN_DLLEXPORT std::unique_ptr get_file_content(const std::string &filename, uint64_t &file_size); + +template InMemFilterStore::~InMemFilterStore() +{ + if (_pts_to_label_offsets != nullptr) + { + delete[] _pts_to_label_offsets; + _pts_to_label_offsets = nullptr; + } + if (_pts_to_label_counts != nullptr) + { + delete[] _pts_to_label_counts; + _pts_to_label_counts = nullptr; + } + if (_pts_to_labels != nullptr) + { + delete[] _pts_to_labels; + _pts_to_labels = nullptr; + } +} + +template +const std::unordered_map> &InMemFilterStore::get_label_to_medoids() const +{ + return this->_filter_to_medoid_ids; +} + +template +const std::vector 
&InMemFilterStore::get_medoids_of_label(const LabelT label) +{ + if (_filter_to_medoid_ids.find(label) != _filter_to_medoid_ids.end()) + { + return this->_filter_to_medoid_ids[label]; + } + else + { + std::stringstream ss; + ss << "Could not find " << label << " in filters_to_medoid_ids map." << std::endl; + diskann::cerr << ss.str(); + throw ANNException(ss.str(), -1); + } +} + +template void InMemFilterStore::set_universal_label(const LabelT univ_label) +{ + _universal_filter_label = univ_label; + _use_universal_label = true; +} + +// Load functions for SEARCH START +template +bool InMemFilterStore::load(const std::string &disk_index_file) { + std::string labels_file = disk_index_file + "_labels.txt"; + std::string labels_to_medoids = disk_index_file + "_labels_to_medoids.txt"; + std::string dummy_map_file = disk_index_file + "_dummy_map.txt"; + std::string labels_map_file = disk_index_file + "_labels_map.txt"; + std::string univ_label_file = disk_index_file + "_universal_label.txt"; + std::string brute_force_data_file = disk_index_file + "_brute_force.txt"; + std::string bf_excluded_label_file = disk_index_file + "_non_brute_force_labels.txt"; + + size_t num_pts_in_label_file = 0; + + // TODO: Check for encoding issues here. We are opening files as binary and + // reading them as bytes, not sure if that can cause an issue with UTF + // encodings. + bool has_filters = true; + if (false == load_file_and_parse(labels_file, &InMemFilterStore::load_label_file)) + { + diskann::cout << "Index does not have filter data. " << std::endl; + return false; + } + if (false == parse_stream(labels_map_file, &InMemFilterStore::load_label_map)) + { + diskann::cerr << "Failed to find file: " << labels_map_file << " while labels_file exists." << std::endl; + return false; + } + + if (false == parse_stream(labels_to_medoids, &InMemFilterStore::load_labels_to_medoids)) + { + diskann::cerr << "Failed to find file: " << labels_to_medoids << " while labels file exists." 
<< std::endl;
+        return false;
+    }
+    // missing universal label file is NOT an error.
+    load_file_and_parse(univ_label_file, &InMemFilterStore::parse_universal_label);
+
+    // missing dummy map file is also NOT an error.
+    parse_stream(dummy_map_file, &InMemFilterStore::load_dummy_map);
+    _is_valid = true;
+    return _is_valid;
+}
+
+// Returns _is_valid, the flag the loading code sets to true once the filter
+// data has been read.
+template <typename LabelT> bool InMemFilterStore<LabelT>::has_filter_support() const
+{
+    return _is_valid;
+}
+
+// Parses the label file held in label_file_content. Every '\n'-terminated line
+// describes one point as a comma-separated list of numeric labels. Fills
+// _pts_to_label_offsets and _pts_to_label_counts (one entry per point), the
+// flat _pts_to_labels array, and _num_points.
+// A final line that is not terminated by '\n' is ignored;
+// get_label_file_metadata() applies the same rule, so the array sizes agree.
+// Throws whatever std::stoul throws for an empty or non-numeric label.
+// TODO: Improve this to not load the entire file in memory
+template <typename LabelT> void InMemFilterStore<LabelT>::load_label_file(const std::string_view &label_file_content)
+{
+    // TODO: This code is very inefficient because it reads the label file twice -
+    // once for computing stats and then for loading the labels. Must merge the
+    // two reads.
+    uint32_t num_pts_in_label_file;
+    uint32_t num_total_labels;
+    get_label_file_metadata(label_file_content, num_pts_in_label_file, num_total_labels);
+
+    _num_points = num_pts_in_label_file;
+
+    _pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
+    _pts_to_label_counts = new uint32_t[num_pts_in_label_file];
+    _pts_to_labels = new LabelT[num_total_labels];
+
+    uint32_t line_cnt = 0;
+    uint32_t labels_seen_so_far = 0;
+    std::string label_str;
+    size_t cur_pos = 0;
+    const size_t file_size = label_file_content.size();
+
+    while (cur_pos < file_size)
+    {
+        const size_t next_pos = label_file_content.find('\n', cur_pos);
+        if (next_pos == std::string_view::npos)
+        {
+            break;
+        }
+
+        _pts_to_label_offsets[line_cnt] = labels_seen_so_far;
+        uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt];
+        num_lbls_in_cur_pt = 0;
+
+        size_t lbl_pos = cur_pos;
+        while (lbl_pos < next_pos)
+        {
+            // A label ends at the next ',' or at the end of the current line,
+            // whichever comes first (find() may return a ',' from a later line,
+            // or npos for the last label in the whole file).
+            size_t next_lbl_pos = label_file_content.find(',', lbl_pos);
+            if (next_lbl_pos == std::string_view::npos || next_lbl_pos > next_pos)
+            {
+                next_lbl_pos = next_pos;
+            }
+
+            // TODO: SHOULD NOT EXPECT label_file_content TO BE NULL_TERMINATED
+            label_str.assign(label_file_content.data() + lbl_pos, next_lbl_pos - lbl_pos);
+            // An empty token (e.g. ",,") must not be indexed at length() - 1.
+            if (!label_str.empty() && label_str.back() == '\t') // '\t' won't exist in label file?
+            {
+                label_str.pop_back();
+            }
+
+            _pts_to_labels[labels_seen_so_far++] = (LabelT)std::stoul(label_str);
+            num_lbls_in_cur_pt++;
+
+            // move to next label
+            lbl_pos = next_lbl_pos + 1;
+        }
+
+        // move to next line
+        cur_pos = next_pos + 1;
+
+        if (num_lbls_in_cur_pt == 0)
+        {
+            // NOTE(review): terminates the whole process from library code;
+            // consider throwing instead once the callers' expectations are confirmed.
+            diskann::cout << "No label found for point " << line_cnt << std::endl;
+            exit(-1);
+        }
+
+        line_cnt++;
+    }
+
+    // TODO: We need to check if the number of labels and the number of points
+    // is as expected. Maybe add the check in PQFlashIndex?
+    // num_points_labels = line_cnt;
+}
+
+// Reads "label,medoid,medoid,..." lines from medoid_stream into
+// _filter_to_medoid_ids, replacing its previous contents. The first
+// comma-separated field of a line is the label, the rest are medoid ids.
+// Blank lines are skipped.
+template <typename LabelT>
+void InMemFilterStore<LabelT>::load_labels_to_medoids(std::basic_istream<char> &medoid_stream)
+{
+    std::string line, token;
+
+    _filter_to_medoid_ids.clear();
+    while (std::getline(medoid_stream, line))
+    {
+        std::istringstream iss(line);
+        std::vector<uint32_t> medoids;
+        LabelT label{};
+        bool label_seen = false;
+        while (std::getline(iss, token, ','))
+        {
+            if (!label_seen)
+            {
+                label = (LabelT)std::stoul(token);
+                label_seen = true;
+            }
+            else
+            {
+                medoids.push_back((uint32_t)std::stoul(token));
+            }
+        }
+        if (!label_seen)
+        {
+            // Blank line: there is no label to key the entry on.
+            continue;
+        }
+        _filter_to_medoid_ids[label].swap(medoids);
+    }
+}
+
+// Reads "<label string>\t<numeric id>" lines from map_reader into _label_map.
+// Blank lines are skipped. The fields are re-read into fresh locals for every
+// line so that a short line can never reuse a value left over from the
+// previous one; a line without a numeric second field makes std::stoul throw.
+template <typename LabelT> void InMemFilterStore<LabelT>::load_label_map(std::basic_istream<char> &map_reader)
+{
+    std::string line;
+    while (std::getline(map_reader, line))
+    {
+        if (line.empty())
+        {
+            continue;
+        }
+        std::istringstream iss(line);
+        std::string label_str;
+        std::string token;
+        getline(iss, label_str, '\t');
+        getline(iss, token, '\t');
+        _label_map[label_str] = (LabelT)std::stoul(token);
+    }
+}
+
+// Interprets content as one unsigned number and registers it as the universal label.
+template <typename LabelT> void InMemFilterStore<LabelT>::parse_universal_label(const std::string_view &content)
+{
+    LabelT label_as_num = (LabelT)std::stoul(std::string(content));
+    this->set_universal_label(label_as_num);
+}
+
+// Reads "dummy_id,real_id" lines from dummy_map_stream and records the mapping
+// in _dummy_pts, _has_dummy_pts, _dummy_to_real_map and _real_to_dummy_map.
+// If a line has more than two fields, the last one is used as the real id.
+// Blank lines are skipped; a line with a single field throws
+// diskann::ANNException instead of inserting an unset id.
+template <typename LabelT> void InMemFilterStore<LabelT>::load_dummy_map(std::basic_istream<char> &dummy_map_stream)
+{
+    std::string line, token;
+
+    while (std::getline(dummy_map_stream, line))
+    {
+        std::istringstream iss(line);
+        uint32_t cnt = 0;
+        uint32_t dummy_id = 0;
+        uint32_t real_id = 0;
+        while (std::getline(iss, token, ','))
+        {
+            if (cnt == 0)
+                dummy_id = (uint32_t)std::stoul(token);
+            else
+                real_id = (uint32_t)std::stoul(token);
+            cnt++;
+        }
+        if (cnt == 0)
+        {
+            continue;
+        }
+        if (cnt < 2)
+        {
+            throw diskann::ANNException(std::string("Malformed line in dummy map: ") + line, -1);
+        }
+        _dummy_pts.insert(dummy_id);
+        _has_dummy_pts.insert(real_id);
+        _dummy_to_real_map[dummy_id] = real_id;
+
+        // operator[] creates the empty list on first use of real_id.
+        _real_to_dummy_map[real_id].emplace_back(dummy_id);
+    }
+    diskann::cout << "Loaded dummy map" << std::endl;
+}
+
+// Fills labels with num_labels entries drawn uniformly at random from the flat
+// _pts_to_labels array, so a label is picked in proportion to how often it
+// occurs in the loaded data. Throws diskann::ANNException when no labels are
+// loaded.
+// nthreads is accepted for interface compatibility only: sampling is done
+// serially because a single std::mt19937 must not be advanced from several
+// threads at once, and the loop is trivially cheap.
+template <typename LabelT>
+void InMemFilterStore<LabelT>::generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
+                                                       const uint32_t nthreads)
+{
+    (void)nthreads;
+    labels.clear();
+    labels.resize(num_labels);
+
+    // With zero points there is no "last point" to read the total from.
+    uint64_t num_total_labels = 0;
+    if (_num_points > 0)
+    {
+        num_total_labels =
+            (uint64_t)_pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1];
+    }
+    if (num_total_labels == 0)
+    {
+        std::stringstream stream;
+        stream << "No labels found in data. Not sampling random labels ";
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_int_distribution<uint64_t> dis(0, num_total_labels - 1);
+
+    for (auto &label : labels)
+    {
+        label = (LabelT)_pts_to_labels[dis(gen)];
+    }
+}
+
+// Clears the stream's error state and seeks back to the beginning.
+template <typename LabelT> void InMemFilterStore<LabelT>::reset_stream_for_reading(std::basic_istream<char> &infile)
+{
+    infile.clear();
+    infile.seekg(0);
+}
+
+// Counts the '\n'-terminated lines (points) and the comma-separated labels in
+// fileContent. A final line without a terminating '\n' is not counted.
+template <typename LabelT>
+void InMemFilterStore<LabelT>::get_label_file_metadata(const std::string_view &fileContent, uint32_t &num_pts,
+                                                        uint32_t &num_total_labels)
+{
+    num_pts = 0;
+    num_total_labels = 0;
+
+    const size_t file_size = fileContent.length();
+    size_t cur_pos = 0;
+    while (cur_pos < file_size)
+    {
+        const size_t next_pos = fileContent.find('\n', cur_pos);
+        if (next_pos == std::string::npos)
+        {
+            break;
+        }
+
+        size_t lbl_pos = cur_pos;
+        while (lbl_pos < next_pos)
+        {
+            size_t next_lbl_pos = fileContent.find(',', lbl_pos);
+            if (next_lbl_pos == std::string::npos) // the last label
+            {
+                next_lbl_pos = next_pos;
+            }
+
+            num_total_labels++;
+
+            // A ',' found beyond next_pos belongs to a later line; stepping past
+            // it still ends this loop, so the label is counted exactly once.
+            lbl_pos = next_lbl_pos + 1;
+        }
+
+        cur_pos = next_pos + 1;
+
+        num_pts++;
+    }
+
+    diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
+                  << std::endl;
+}
+
+// Opens filename as a text stream and hands it to parse_fn on this object.
+// Returns false when the file does not exist and true after parsing; throws
+// diskann::ANNException when the file exists but cannot be opened.
+template <typename LabelT>
+bool InMemFilterStore<LabelT>::parse_stream(const std::string &filename,
+                                            void (InMemFilterStore::*parse_fn)(std::basic_istream<char> &stream))
+{
+    if (!file_exists(filename))
+    {
+        return false;
+    }
+
+    std::ifstream stream(filename);
+    if (stream.fail())
+    {
+        std::stringstream ss;
+        ss << "Could not open file: " << filename << std::endl;
+        throw diskann::ANNException(ss.str(), -1);
+    }
+
+    std::invoke(parse_fn, this, stream);
+    return true;
+}
+
+// Reads all of filename into memory and hands the bytes to parse_fn on this
+// object. Returns false when the file does not exist and true after parsing.
+// The view passed to parse_fn is only valid for the duration of the call.
+template <typename LabelT>
+bool InMemFilterStore<LabelT>::load_file_and_parse(const std::string &filename,
+                                                   void (InMemFilterStore::*parse_fn)(const std::string_view &content))
+{
+    if (!file_exists(filename))
+    {
+        return false;
+    }
+
+    // Must be exactly the type get_file_content() takes by reference.
+    uint64_t file_size = 0;
+    auto file_content_ptr = get_file_content(filename, file_size);
+    std::string_view content_as_str(file_content_ptr.get(), file_size);
+    std::invoke(parse_fn, this, content_as_str);
+    return true;
+}
+
+// Reads the whole of filename in binary mode. Returns the bytes (NOT
+// null-terminated) and stores their count in file_size. Throws
+// diskann::ANNException if the file cannot be opened, sized, or read in full.
+std::unique_ptr<char[]> get_file_content(const std::string &filename, uint64_t &file_size)
+{
+    std::ifstream infile(filename, std::ios::binary);
+    if (infile.fail())
+    {
+        throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
+    }
+    infile.seekg(0, std::ios::end);
+    const std::streamoff end_pos = infile.tellg();
+    if (end_pos < 0)
+    {
+        // tellg() reports failure as -1, which must not be used as a size.
+        throw diskann::ANNException(std::string("Failed to determine size of file ") + filename, -1);
+    }
+    file_size = (uint64_t)end_pos;
+
+    std::unique_ptr<char[]> buffer(new char[file_size]);
+    infile.seekg(0, std::ios::beg);
+    infile.read(buffer.get(), (std::streamsize)file_size);
+    if (!infile)
+    {
+        throw diskann::ANNException(std::string("Failed to read file ") + filename, -1);
+    }
+
+    return buffer;
+}
+// Load functions for SEARCH END
+// NOTE(review): the template arguments of the three explicit instantiations
+// below are not recoverable from this excerpt - confirm against the original patch.
+template class InMemFilterStore;
+template class InMemFilterStore;
+template class InMemFilterStore;
+
+} // namespace diskann
diff --git a/src/index.cpp b/src/index.cpp
index bf93344fa..f567ea8e6 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -1,18 +1,18 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT license.
-#include - -#include - +#include "ann_exception.h" #include "boost/dynamic_bitset.hpp" #include "index_factory.h" #include "memory_mapper.h" +#include "tag_uint128.h" #include "timer.h" #include "tsl/robin_map.h" #include "tsl/robin_set.h" +#include "utils.h" #include "windows_customizations.h" -#include "tag_uint128.h" +#include +#include #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -21,6 +21,7 @@ #include #endif +#include "filter_utils.h" #include "index.h" #define MAX_POINTS_FOR_USING_BITSET 10000000 @@ -92,15 +93,15 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrsearch_list_size; - _indexingRange = index_config.index_write_params->max_degree; - _indexingMaxC = index_config.index_write_params->max_occlusion_size; - _indexingAlpha = index_config.index_write_params->alpha; - _filterIndexingQueueSize = index_config.index_write_params->filter_list_size; - _indexingThreads = index_config.index_write_params->num_threads; - _saturate_graph = index_config.index_write_params->saturate_graph; + if (index_config.index_write_params != nullptr) { + _indexingQueueSize = index_config.index_write_params->search_list_size; + _indexingRange = index_config.index_write_params->max_degree; + _indexingMaxC = index_config.index_write_params->max_occlusion_size; + _indexingAlpha = index_config.index_write_params->alpha; + _filter_indexing_queue_size = + index_config.index_write_params->filter_list_size; + _indexingThreads = index_config.index_write_params->num_threads; + _saturate_graph = index_config.index_write_params->saturate_graph; if (index_config.index_search_params != nullptr) { @@ -331,7 +332,8 @@ void Index::save(const char *filename, bool compact_before_save } label_writer.close(); - // write compacted raw_labels if data hence _location_to_labels was also compacted + // write compacted raw_labels if data hence _location_to_labels was also + // compacted if 
(compact_before_save && _dynamic_index) { _label_map = load_label_map(std::string(filename) + "_labels_map.txt"); @@ -732,8 +734,8 @@ template int Index template uint32_t Index::calculate_entry_point() { - // REFACTOR TODO: This function does not support multi-threaded calculation of medoid. - // Must revisit if perf is a concern. + // REFACTOR TODO: This function does not support multi-threaded calculation of + // medoid. Must revisit if perf is a concern. return _data_store->calculate_medoid(); } @@ -1303,19 +1305,18 @@ template void Index> manager(_query_scratch); - auto scratch = manager.scratch_space(); - std::vector pruned_list; - if (_filtered_index) - { - search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize); - } - else - { - search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch); - } - assert(pruned_list.size() > 0); + // Find and add appropriate graph edges + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + std::vector pruned_list; + if (_filtered_index) { + search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch, + true, _filter_indexing_queue_size); + } else { + search_for_point_and_prune(node, _indexingQueueSize, pruned_list, + scratch); + } + assert(pruned_list.size() > 0); { LockGuard guard(_locks[node]); @@ -1655,13 +1656,15 @@ void Index::build(const char *filename, const size_t num_points throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - // REFACTOR PQ TODO: We can remove this if and add a check in the InMemDataStore - // to not populate_data if it has been called once. + // REFACTOR PQ TODO: We can remove this if and add a check in the + // InMemDataStore to not populate_data if it has been called once. if (_pq_dist) { #ifdef EXEC_ENV_OLS std::stringstream ss; - ss << "PQ Build is not supported in DLVS environment (i.e. 
if EXEC_ENV_OLS is defined)" << std::endl; + ss << "PQ Build is not supported in DLVS environment (i.e. if EXEC_ENV_OLS " + "is defined)" + << std::endl; diskann::cerr << ss.str() << std::endl; throw ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); #else @@ -1819,20 +1822,22 @@ void Index::parse_label_file(const std::string &label_file, siz infile.seekg(0, std::ios::beg); line_cnt = 0; - while (std::getline(infile, line)) - { - std::istringstream iss(line); - std::vector lbls(0); - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) - { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); - lbls.push_back(token_as_num); - _labels.insert(token_as_num); - } + while (std::getline(infile, line)) { + if (line.find(NO_LABEL_FOR_POINT) != std::string::npos) { + line_cnt++; + continue; + } + std::istringstream iss(line); + std::vector lbls(0); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + LabelT token_as_num = (LabelT)std::stoul(token); + lbls.push_back(token_as_num); + _labels.insert(token_as_num); + } std::sort(lbls.begin(), lbls.end()); _location_to_labels[line_cnt] = lbls; @@ -1864,8 +1869,8 @@ void Index::build_filtered_index(const char *filename, const st size_t num_points_labels = 0; parse_label_file(label_file, - num_points_labels); // determines medoid for each label and identifies - // the points to label mapping + num_points_labels); // determines medoid for each label and + // identifies the points to label mapping std::unordered_map> label_to_points; @@ -2891,10 +2896,10 @@ int Index::insert_point(const T *point, const TagT tag, const s { if 
(_frozen_pts_used >= _num_frozen_pts) { - throw ANNException( - "Error: For dynamic filtered index, the number of frozen points should be atleast equal " - "to number of unique labels.", - -1); + throw ANNException("Error: For dynamic filtered index, the number of " + "frozen points should be atleast equal " + "to number of unique labels.", + -1); } auto fz_location = (int)(_max_points) + _frozen_pts_used; // as first _fz_point @@ -2950,7 +2955,8 @@ int Index::insert_point(const T *point, const TagT tag, const s // Insert tag and mapping to location if (_enable_tags) { - // if tags are enabled and tag is already inserted. so we can't reuse that tag. + // if tags are enabled and tag is already inserted. so we can't reuse that + // tag. if (_tag_to_location.find(tag) != _tag_to_location.end()) { release_location(location); @@ -2964,20 +2970,22 @@ int Index::insert_point(const T *point, const TagT tag, const s _data_store->set_vector(location, point); // update datastore - // Find and add appropriate graph edges - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - std::vector pruned_list; // it is the set best candidates to connect to this point - if (_filtered_index) - { - // when filtered the best_candidates will share the same label ( label_present > distance) - search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize); - } - else - { - search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch); - } - assert(pruned_list.size() > 0); // should find atleast one neighbour (i.e frozen point acting as medoid) + // Find and add appropriate graph edges + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + std::vector + pruned_list; // it is the set best candidates to connect to this point + if (_filtered_index) { + // when filtered the best_candidates will share the same label ( + // label_present > distance) + 
search_for_point_and_prune(location, _indexingQueueSize, pruned_list, + scratch, true, _filter_indexing_queue_size); + } else { + search_for_point_and_prune(location, _indexingQueueSize, pruned_list, + scratch); + } + assert(pruned_list.size() > 0); // should find atleast one neighbour (i.e + // frozen point acting as medoid) { std::shared_lock tlock(_tag_lock, std::defer_lock); diff --git a/src/index_factory.cpp b/src/index_factory.cpp index 35790f8d6..08b89da1d 100644 --- a/src/index_factory.cpp +++ b/src/index_factory.cpp @@ -45,7 +45,8 @@ void IndexFactory::check_config() _config->tag_type != "uint64") { throw ANNException("ERROR: invalid data type : + " + _config->tag_type + - " is not supported. please select from [int32, uint32, int64, uint64]", + " is not supported. please select from [int32, " + "uint32, int64, uint64]", -1); } } @@ -109,7 +110,8 @@ std::shared_ptr> IndexFactory::construct_pq_datastore(DataStoreSt return std::make_shared>(dimension, (location_t)(num_points), num_pq_chunks, std::move(distance_fn), std::move(quantized_distance_fn)); default: - // REFACTOR TODO: We do support diskPQ - so we may need to add a new class for SSDPQDataStore! + // REFACTOR TODO: We do support diskPQ - so we may need to add a new class + // for SSDPQDataStore! 
break; } return nullptr; @@ -120,7 +122,8 @@ std::unique_ptr IndexFactory::create_instance() { size_t num_points = _config->max_points + _config->num_frozen_pts; size_t dim = _config->dimension; - // auto graph_store = construct_graphstore(_config->graph_strategy, num_points); + // auto graph_store = construct_graphstore(_config->graph_strategy, + // num_points); auto data_store = construct_datastore(_config->data_strategy, num_points, dim, _config->metric); std::shared_ptr> pq_data_store = nullptr; @@ -140,8 +143,9 @@ std::unique_ptr IndexFactory::create_instance() std::unique_ptr graph_store = construct_graphstore(_config->graph_strategy, num_points + _config->num_frozen_pts, max_reserve_degree); - // REFACTOR TODO: Must construct in-memory PQDatastore if strategy == ONDISK and must construct - // in-mem and on-disk PQDataStore if strategy == ONDISK and diskPQ is required. + // REFACTOR TODO: Must construct in-memory PQDatastore if strategy == ONDISK + // and must construct in-mem and on-disk PQDataStore if strategy == ONDISK and + // diskPQ is required. 
return std::make_unique>(*_config, data_store, std::move(graph_store), pq_data_store); } @@ -185,7 +189,9 @@ std::unique_ptr IndexFactory::create_instance(const std::string & return create_instance(label_type); } else - throw ANNException("Error: unsupported tag_type please choose from [int32/uint32/int64/uint64]", -1); + throw ANNException("Error: unsupported tag_type please choose from " + "[int32/uint32/int64/uint64]", + -1); } template @@ -203,11 +209,17 @@ std::unique_ptr IndexFactory::create_instance(const std::string & throw ANNException("Error: unsupported label_type please choose from [uint/ushort]", -1); } -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); +// template DISKANN_DLLEXPORT std::shared_ptr> +// IndexFactory::construct_datastore( +// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric +// m); +// template DISKANN_DLLEXPORT std::shared_ptr> +// IndexFactory::construct_datastore( +// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric +// m); +// template DISKANN_DLLEXPORT std::shared_ptr> +// IndexFactory::construct_datastore( +// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric +// m); } // namespace diskann diff --git a/src/linux_aligned_file_reader.cpp b/src/linux_aligned_file_reader.cpp index 31bf5f827..94e14dc08 100644 --- a/src/linux_aligned_file_reader.cpp +++ b/src/linux_aligned_file_reader.cpp @@ -3,11 +3,11 @@ #include "linux_aligned_file_reader.h" +#include "tsl/robin_map.h" +#include "utils.h" #include #include #include -#include 
"tsl/robin_map.h" -#include "utils.h" #define MAX_EVENTS 1024 namespace @@ -149,7 +149,9 @@ void LinuxAlignedFileReader::register_thread() lk.unlock(); if (ret == -EAGAIN) { - std::cerr << "io_setup() failed with EAGAIN: Consider increasing /proc/sys/fs/aio-max-nr" << std::endl; + std::cerr << "io_setup() failed with EAGAIN: Consider increasing " + "/proc/sys/fs/aio-max-nr" + << std::endl; } else { diff --git a/src/math_utils.cpp b/src/math_utils.cpp index 7481da848..5ce66fb2e 100644 --- a/src/math_utils.cpp +++ b/src/math_utils.cpp @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include "logger.h" +#include "utils.h" #include #include #include #include -#include "logger.h" -#include "utils.h" namespace math_utils { diff --git a/src/memory_mapper.cpp b/src/memory_mapper.cpp index d1c5ef984..819df7fec 100644 --- a/src/memory_mapper.cpp +++ b/src/memory_mapper.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include "logger.h" #include "memory_mapper.h" +#include "logger.h" #include #include diff --git a/src/partition.cpp b/src/partition.cpp index 570d45c7d..1428eb801 100644 --- a/src/partition.cpp +++ b/src/partition.cpp @@ -7,20 +7,20 @@ #include #include -#include #include "tsl/robin_map.h" #include "tsl/robin_set.h" +#include #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif -#include "utils.h" -#include "math_utils.h" #include "index.h" -#include "parameters.h" +#include "math_utils.h" #include "memory_mapper.h" +#include "parameters.h" #include "partition.h" +#include "utils.h" #ifdef _WINDOWS #include #endif diff --git a/src/pq.cpp b/src/pq.cpp index d2b545c79..d1cc8e861 100644 --- a/src/pq.cpp +++ b/src/pq.cpp @@ -5,9 +5,9 @@ #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif -#include "pq.h" -#include "partition.h" #include "math_utils.h" +#include "partition.h" +#include "pq.h" #include "tsl/robin_map.h" // block size for reading/processing large files and matrices in blocks @@ -354,8 +354,9 @@ void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_n // make_zero_mean is false by default. // These assumptions allow to make the function much simpler and avoid storing // array of chunk_offsets and centroids. -// The compiler pragma for multi-threading support is removed from this implementation -// for the purpose of integration into systems that strictly control resource allocation. +// The compiler pragma for multi-threading support is removed from this +// implementation for the purpose of integration into systems that strictly +// control resource allocation. 
int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim, size_t num_pq_chunks, std::vector &pivot_data_vector) { @@ -771,18 +772,19 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, uint32 return 0; } -// generate_pq_data_from_pivots_simplified is a simplified version of generate_pq_data_from_pivots. -// Input is provided in the in-memory buffers data and pivot_data. -// Output is stored in the in-memory buffer pq. -// Simplification is based on the following assumptions: +// generate_pq_data_from_pivots_simplified is a simplified version of +// generate_pq_data_from_pivots. Input is provided in the in-memory buffers data +// and pivot_data. Output is stored in the in-memory buffer pq. Simplification +// is based on the following assumptions: // supporting only float data type // dim % num_pq_chunks == 0, which results in a fixed chunk_size // num_centers == 256 by default // make_zero_mean is false by default. // These assumptions allow to make the function much simpler and avoid using // array of chunk_offsets and centroids. -// The compiler pragma for multi-threading support is removed from this implementation -// for the purpose of integration into systems that strictly control resource allocation. +// The compiler pragma for multi-threading support is removed from this +// implementation for the purpose of integration into systems that strictly +// control resource allocation. 
int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, const float *pivot_data, const size_t pivots_num, const size_t dim, const size_t num_pq_chunks, std::vector &pq) diff --git a/src/pq_data_store.cpp b/src/pq_data_store.cpp index c47c16705..16e94d40e 100644 --- a/src/pq_data_store.cpp +++ b/src/pq_data_store.cpp @@ -1,10 +1,10 @@ #include -#include "pq_data_store.h" +#include "distance.h" #include "pq.h" +#include "pq_data_store.h" #include "pq_scratch.h" #include "utils.h" -#include "distance.h" namespace diskann { diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index d9ad50617..80c1c7460 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -3,11 +3,12 @@ #include "common_includes.h" -#include "timer.h" +#include "cosine_similarity.h" +#include "in_mem_filter_store.h" #include "pq.h" -#include "pq_scratch.h" #include "pq_flash_index.h" -#include "cosine_similarity.h" +#include "pq_scratch.h" +#include "timer.h" #ifdef _WINDOWS #include "windows_aligned_file_reader.h" @@ -37,8 +38,10 @@ PQFlashIndex::PQFlashIndex(std::shared_ptr &fileRe { if (std::is_floating_point::value) { - diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed " - "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we " + diskann::cout << "Since data is floating point, we assume that it has " + "been appropriately pre-processed " + "(normalization for cosine, and convert-to-l2 by " + "adding extra dimension for MIPS). So we " "shall invoke an l2 distance function." 
<< std::endl; metric_to_invoke = diskann::Metric::L2; @@ -53,6 +56,7 @@ PQFlashIndex::PQFlashIndex(std::shared_ptr &fileRe this->_dist_cmp.reset(diskann::get_distance_function(metric_to_invoke)); this->_dist_cmp_float.reset(diskann::get_distance_function(metric_to_invoke)); + this->_filter_store = std::make_unique>(); } template PQFlashIndex::~PQFlashIndex() @@ -73,6 +77,12 @@ template PQFlashIndex::~PQFlashIndex() diskann::aligned_free(_coord_cache_buf); } + if (_medoids != nullptr) + { + delete[] _medoids; + _medoids = nullptr; + } + if (_load_flag) { diskann::cout << "Clearing scratch" << std::endl; @@ -81,22 +91,6 @@ template PQFlashIndex::~PQFlashIndex() this->reader->deregister_all_threads(); reader->close(); } - if (_pts_to_label_offsets != nullptr) - { - delete[] _pts_to_label_offsets; - } - if (_pts_to_label_counts != nullptr) - { - delete[] _pts_to_label_counts; - } - if (_pts_to_labels != nullptr) - { - delete[] _pts_to_labels; - } - if (_medoids != nullptr) - { - delete[] _medoids; - } } template inline uint64_t PQFlashIndex::get_node_sector(uint64_t node_id) @@ -269,7 +263,8 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin #endif if (num_nodes_to_cache >= this->_num_points) { - // for small num_points and big num_nodes_to_cache, use below way to get the node_list quickly + // for small num_points and big num_nodes_to_cache, use below way to get + // the node_list quickly node_list.resize(this->_num_points); for (uint32_t i = 0; i < this->_num_points; ++i) { @@ -312,19 +307,21 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin bool filtered_search = false; std::vector random_query_filters(sample_num); - if (_filter_to_medoid_ids.size() != 0) + if (this->_filter_index) { filtered_search = true; - generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads); + _filter_store->generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads); } #pragma omp parallel for 
schedule(dynamic, 1) num_threads(nthreads) for (int64_t i = 0; i < (int64_t)sample_num; i++) { auto &label_for_search = random_query_filters[i]; - // run a search on the sample query with a random label (sampled from base label distribution), and it will - // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the - // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data. + // run a search on the sample query with a random label (sampled from base + // label distribution), and it will concurrently update the + // node_visit_counter to track most visited nodes. The last false is to + // not use the "use_reorder_data" option which enables a final reranking + // if the disk index itself contains only PQ data. cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); } @@ -374,9 +371,10 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: cur_level->insert(_medoids[miter]); } - if ((_filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache)) + auto filter_to_medoid_ids = _filter_store->get_label_to_medoids(); + if ((filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache)) { - for (auto &x : _filter_to_medoid_ids) + for (auto &x : filter_to_medoid_ids) { for (auto &y : x.second) { @@ -533,226 +531,6 @@ template void PQFlashIndex::use_medoids } } -template -void PQFlashIndex::generate_random_labels(std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads) -{ - std::random_device rd; - labels.clear(); - labels.resize(num_labels); - - uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1]; - std::mt19937 gen(rd()); - if (num_total_labels == 0) - { - std::stringstream stream; - stream << "No labels found in data. 
Not sampling random labels "; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - std::uniform_int_distribution dis(0, num_total_labels - 1); - -#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (int64_t i = 0; i < num_labels; i++) - { - uint64_t rnd_loc = dis(gen); - labels[i] = (LabelT)_pts_to_labels[rnd_loc]; - } -} - -template -std::unordered_map PQFlashIndex::load_label_map(std::basic_istream &map_reader) -{ - std::unordered_map string_to_int_mp; - std::string line, token; - LabelT token_as_num; - std::string label_str; - while (std::getline(map_reader, line)) - { - std::istringstream iss(line); - getline(iss, token, '\t'); - label_str = token; - getline(iss, token, '\t'); - token_as_num = (LabelT)std::stoul(token); - string_to_int_mp[label_str] = token_as_num; - } - return string_to_int_mp; -} - -template -LabelT PQFlashIndex::get_converted_label(const std::string &filter_label) -{ - if (_label_map.find(filter_label) != _label_map.end()) - { - return _label_map[filter_label]; - } - if (_use_universal_label) - { - return _universal_filter_label; - } - std::stringstream stream; - stream << "Unable to find label in the Label Map"; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); -} - -template -void PQFlashIndex::reset_stream_for_reading(std::basic_istream &infile) -{ - infile.clear(); - infile.seekg(0); -} - -template -void PQFlashIndex::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, - uint32_t &num_total_labels) -{ - num_pts = 0; - num_total_labels = 0; - - size_t file_size = fileContent.length(); - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - while (cur_pos < file_size && cur_pos != std::string::npos) - { - next_pos = fileContent.find('\n', cur_pos); - if (next_pos == std::string::npos) - { - break; - } - - size_t lbl_pos = 
cur_pos; - size_t next_lbl_pos = 0; - while (lbl_pos < next_pos && lbl_pos != std::string::npos) - { - next_lbl_pos = fileContent.find(',', lbl_pos); - if (next_lbl_pos == std::string::npos) // the last label - { - next_lbl_pos = next_pos; - } - - num_total_labels++; - - lbl_pos = next_lbl_pos + 1; - } - - cur_pos = next_pos + 1; - - num_pts++; - } - - diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels - << std::endl; -} - -template -inline bool PQFlashIndex::point_has_label(uint32_t point_id, LabelT label_id) -{ - uint32_t start_vec = _pts_to_label_offsets[point_id]; - uint32_t num_lbls = _pts_to_label_counts[point_id]; - bool ret_val = false; - for (uint32_t i = 0; i < num_lbls; i++) - { - if (_pts_to_labels[start_vec + i] == label_id) - { - ret_val = true; - break; - } - } - return ret_val; -} - -template -void PQFlashIndex::parse_label_file(std::basic_istream &infile, size_t &num_points_labels) -{ - infile.seekg(0, std::ios::end); - size_t file_size = infile.tellg(); - - std::string buffer(file_size, ' '); - - infile.seekg(0, std::ios::beg); - infile.read(&buffer[0], file_size); - - std::string line; - uint32_t line_cnt = 0; - - uint32_t num_pts_in_label_file; - uint32_t num_total_labels; - get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels); - - _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; - _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; - _pts_to_labels = new LabelT[num_total_labels]; - uint32_t labels_seen_so_far = 0; - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - while (cur_pos < file_size && cur_pos != std::string::npos) - { - next_pos = buffer.find('\n', cur_pos); - if (next_pos == std::string::npos) - { - break; - } - - _pts_to_label_offsets[line_cnt] = labels_seen_so_far; - uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; - num_lbls_in_cur_pt = 0; - - size_t lbl_pos = cur_pos; - size_t next_lbl_pos = 0; - 
while (lbl_pos < next_pos && lbl_pos != std::string::npos) - { - next_lbl_pos = buffer.find(',', lbl_pos); - if (next_lbl_pos == std::string::npos) // the last label in the whole file - { - next_lbl_pos = next_pos; - } - - if (next_lbl_pos > next_pos) // the last label in one line, just read to the end - { - next_lbl_pos = next_pos; - } - - label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); - if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? - { - label_str.erase(label_str.length() - 1); - } - - LabelT token_as_num = (LabelT)std::stoul(label_str); - _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; - num_lbls_in_cur_pt++; - - // move to next label - lbl_pos = next_lbl_pos + 1; - } - - // move to next line - cur_pos = next_pos + 1; - - if (num_lbls_in_cur_pt == 0) - { - diskann::cout << "No label found for point " << line_cnt << std::endl; - exit(-1); - } - - line_cnt++; - } - - num_points_labels = line_cnt; - reset_stream_for_reading(infile); -} - -template void PQFlashIndex::set_universal_label(const LabelT &label) -{ - _use_universal_label = true; - _universal_filter_label = label; -} - #ifdef EXEC_ENV_OLS template int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) @@ -791,12 +569,6 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; - std::string labels_file = std ::string(_disk_index_file) + "_labels.txt"; - std::string labels_to_medoids = std ::string(_disk_index_file) + "_labels_to_medoids.txt"; - std::string dummy_map_file = std ::string(_disk_index_file) + "_dummy_map.txt"; - std::string labels_map_file = std ::string(_disk_index_file) + "_labels_map.txt"; - size_t num_pts_in_label_file = 0; - size_t pq_file_dim, pq_file_num_centroids; #ifdef EXEC_ENV_OLS get_bin_metadata(files, 
pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); @@ -824,147 +596,30 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #else diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); #endif - this->_num_points = npts_u64; this->_n_chunks = nchunks_u64; -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_file)) - { - FileContent &content_labels = files.getContent(labels_file); - std::stringstream infile(std::string((const char *)content_labels._content, content_labels._size)); -#else - if (file_exists(labels_file)) - { - std::ifstream infile(labels_file, std::ios::binary); - if (infile.fail()) - { - throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1); - } -#endif - parse_label_file(infile, num_pts_in_label_file); - assert(num_pts_in_label_file == this->_num_points); -#ifndef EXEC_ENV_OLS - infile.close(); -#endif - -#ifdef EXEC_ENV_OLS - FileContent &content_labels_map = files.getContent(labels_map_file); - std::stringstream map_reader(std::string((const char *)content_labels_map._content, content_labels_map._size)); -#else - std::ifstream map_reader(labels_map_file); -#endif - _label_map = load_label_map(map_reader); - -#ifndef EXEC_ENV_OLS - map_reader.close(); -#endif - -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_to_medoids)) - { - FileContent &content_labels_to_meoids = files.getContent(labels_to_medoids); - std::stringstream medoid_stream( - std::string((const char *)content_labels_to_meoids._content, content_labels_to_meoids._size)); -#else - if (file_exists(labels_to_medoids)) - { - std::ifstream medoid_stream(labels_to_medoids); - assert(medoid_stream.is_open()); -#endif - std::string line, token; - - _filter_to_medoid_ids.clear(); - try - { - while (std::getline(medoid_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - std::vector medoids; - LabelT label; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - label = 
(LabelT)std::stoul(token); - else - medoids.push_back((uint32_t)stoul(token)); - cnt++; - } - _filter_to_medoid_ids[label].swap(medoids); - } - } - catch (std::system_error &e) - { - throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); - } - } - std::string univ_label_file = std ::string(_disk_index_file) + "_universal_label.txt"; - -#ifdef EXEC_ENV_OLS - if (files.fileExists(univ_label_file)) - { - FileContent &content_univ_label = files.getContent(univ_label_file); - std::stringstream universal_label_reader( - std::string((const char *)content_univ_label._content, content_univ_label._size)); -#else - if (file_exists(univ_label_file)) + _filter_store = std::make_unique>(); + try + { + _filter_index = _filter_store->load(_disk_index_file); + if (_filter_index) { - std::ifstream universal_label_reader(univ_label_file); - assert(universal_label_reader.is_open()); -#endif - std::string univ_label; - universal_label_reader >> univ_label; -#ifndef EXEC_ENV_OLS - universal_label_reader.close(); -#endif - LabelT label_as_num = (LabelT)std::stoul(univ_label); - set_universal_label(label_as_num); + diskann::cout << "Index has filter support. 
" << std::endl; } - -#ifdef EXEC_ENV_OLS - if (files.fileExists(dummy_map_file)) - { - FileContent &content_dummy_map = files.getContent(dummy_map_file); - std::stringstream dummy_map_stream( - std::string((const char *)content_dummy_map._content, content_dummy_map._size)); -#else - if (file_exists(dummy_map_file)) + else { - std::ifstream dummy_map_stream(dummy_map_file); - assert(dummy_map_stream.is_open()); -#endif - std::string line, token; - - while (std::getline(dummy_map_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - dummy_id = (uint32_t)stoul(token); - else - real_id = (uint32_t)stoul(token); - cnt++; - } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; - - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); - - _real_to_dummy_map[real_id].emplace_back(dummy_id); - } -#ifndef EXEC_ENV_OLS - dummy_map_stream.close(); -#endif - diskann::cout << "Loaded dummy map" << std::endl; + diskann::cout << "Index does not have filter support." << std::endl; } } + catch (diskann::ANNException &ex) + { + // If filter_store=>load() returns false, it means filters are not + // enabled. If it throws, it means there was an error in processing a + // filter index. 
+ diskann::cerr << "Filter index load failed because: " << ex.what() << std::endl; + return false; + } #ifdef EXEC_ENV_OLS _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); @@ -1026,8 +681,8 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::ifstream index_metadata(_disk_index_file, std::ios::binary); #endif - uint32_t nr, nc; // metadata itself is stored as bin format (nr is number of - // metadata, nc should be 1) + uint32_t nr, nc; // metadata itself is stored as bin format (nr is number + // of metadata, nc should be 1) READ_U32(index_metadata, nr); READ_U32(index_metadata, nc); @@ -1271,7 +926,6 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { - uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, @@ -1294,8 +948,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t float *query_rotated = pq_query_scratch->rotated_query; // normalization step. for cosine, we simply normalize the query - // for mips, we normalize the first d-1 dims, and add a 0 for last dim, since an extra coordinate was used to - // convert MIPS to L2 search + // for mips, we normalize the first d-1 dims, and add a 0 for last dim, + // since an extra coordinate was used to convert MIPS to L2 search if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE) { uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? this->_data_dim : (uint64_t)(this->_data_dim - 1); @@ -1335,8 +989,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t _nnodes_per_sector > 0 ? 
1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); // query <-> PQ chunk centers distances - _pq_table.preprocess_query(query_rotated); // center the query and rotate if - // we have a rotation matrix + _pq_table.preprocess_query(query_rotated); // center the query and rotate + // if we have a rotation matrix float *pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch; _pq_table.populate_chunk_distances(query_rotated, pq_dists); @@ -1374,13 +1028,18 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } else { - if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) + const auto &medoid_ids = _filter_store->get_medoids_of_label(filter_label); + if (medoid_ids.size() > 0) + // if (_filter_to_medoid_ids.find(filter_label) != + // _filter_to_medoid_ids.end()) { - const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; + // const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; + for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) { - // for filtered index, we dont store global centroid data as for unfiltered index, so we use PQ distance - // as approximation to decide closest medoid matching the query filter. + // for filtered index, we dont store global centroid data as for + // unfiltered index, so we use PQ distance as approximation to decide + // closest medoid matching the query filter. compute_dists(&medoid_ids[cur_m], 1, dist_scratch); float cur_expanded_dist = dist_scratch[0]; if (cur_expanded_dist < best_dist) @@ -1520,11 +1179,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // unfiltered search, but filtered index! 
+ if (!use_filter && _filter_store->is_dummy_point(id)) continue; - if (use_filter && !(point_has_label(id, filter_label)) && - (!_use_universal_label || !point_has_label(id, _universal_filter_label))) + // if (use_filter && !(point_has_label(id, filter_label)) && + // (!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && !_filter_store->point_has_label_or_universal_label(id, filter_label)) continue; cmps++; float dist = dist_scratch[m]; @@ -1583,11 +1246,14 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filter && _filter_store->is_dummy_point(id)) continue; - if (use_filter && !(point_has_label(id, filter_label)) && - (!_use_universal_label || !point_has_label(id, _universal_filter_label))) + // if (use_filter && !(point_has_label(id, filter_label)) && + // (!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && !_filter_store->point_has_label_or_universal_label(id, filter_label)) continue; cmps++; float dist = dist_scratch[m]; @@ -1668,9 +1334,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t { indices[i] = full_retset[i].id; auto key = (uint32_t)indices[i]; - if (_dummy_pts.find(key) != _dummy_pts.end()) + if (_filter_store->is_dummy_point(key)) { - indices[i] = _dummy_to_real_map[key]; + indices[i] = _filter_store->get_real_point_for_dummy(key); } if (distances != nullptr) @@ -1752,6 +1418,12 @@ template diskann::Metric PQFlashIndex:: return this->metric; } +template +LabelT PQFlashIndex::get_converted_label(const std::string &filter_label) +{ + return _filter_store->get_converted_label(filter_label); +} + #ifdef EXEC_ENV_OLS template char *PQFlashIndex::getHeaderBytes() { diff --git a/src/pq_l2_distance.cpp 
b/src/pq_l2_distance.cpp index c08744c35..9168d26be 100644 --- a/src/pq_l2_distance.cpp +++ b/src/pq_l2_distance.cpp @@ -1,6 +1,6 @@ -#include "pq.h" #include "pq_l2_distance.h" +#include "pq.h" #include "pq_scratch.h" // block size for reading/processing large files and matrices in blocks diff --git a/src/scratch.cpp b/src/scratch.cpp index 1f8a34bb1..2203dcbc3 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include +#include -#include "scratch.h" #include "pq_scratch.h" +#include "scratch.h" namespace diskann { diff --git a/src/windows_aligned_file_reader.cpp b/src/windows_aligned_file_reader.cpp index 3650b928a..4ddd50902 100644 --- a/src/windows_aligned_file_reader.cpp +++ b/src/windows_aligned_file_reader.cpp @@ -4,8 +4,8 @@ #ifdef _WINDOWS #ifndef USE_BING_INFRA #include "windows_aligned_file_reader.h" -#include #include "utils.h" +#include #include #define SECTOR_LEN 4096