diff --git a/include/abstract_filter_store.h b/include/abstract_filter_store.h new file mode 100644 index 000000000..64739845e --- /dev/null +++ b/include/abstract_filter_store.h @@ -0,0 +1,71 @@ +#pragma once +#include "common_includes.h" +#include "utils.h" +#include + +namespace diskann +{ + +enum class FilterMatchStrategy +{ + SET_INTERSECTION +}; +// This class is responsible for filter actions in index, and should not be used outside. +template class AbstractFilterStore +{ + public: + AbstractFilterStore(const size_t num_points); + virtual ~AbstractFilterStore() = default; + + // needs some internal lock + abstract implementation + virtual bool detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels, + const FilterMatchStrategy strategy = FilterMatchStrategy::SET_INTERSECTION) = 0; + + DISKANN_DLLEXPORT virtual const std::vector &get_labels_by_location(const location_t point_id) = 0; + virtual void set_labels_to_location(const location_t location, const std::vector &labels) = 0; + virtual void swap_labels(const location_t location_first, const location_t location_second) = 0; + + virtual const tsl::robin_set &get_all_label_set() = 0; + virtual void add_to_label_set(const label_type &label) = 0; + // Throws: out of range exception + virtual void add_label_to_location(const location_t point_id, const label_type label) = 0; + // returns internal mapping for given raw_label + DISKANN_DLLEXPORT virtual label_type get_numeric_label(const std::string &raw_label) = 0; + + // TODO: in future we may accept a set or vector of universal labels + // DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0; + virtual void set_universal_label(const std::string &universal_labels) = 0; + virtual std::pair get_universal_label() = 0; + + // takes raw label file and then genrate internal mapping file and keep the info of mapping + DISKANN_DLLEXPORT virtual size_t populate_labels(const std::string &raw_labels_file, + const std::string &raw_universal_label) = 0; + + // save labels, labels_map and universal_label to files + virtual void save(const std::string &save_path, const size_t total_points) = 0; + + // load labels, labels_map and universal_label to filter store variables & returns total number of points + virtual size_t load(const std::string &load_path) = 0; + + // For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of + // raw labels to compute GT correctly. + DISKANN_DLLEXPORT virtual void save_raw_labels(const std::string &save_path, const size_t total_points) = 0; + + protected: + // This is for internal use and only loads already parsed file + virtual size_t load_labels(const std::string &labels_file) = 0; + virtual void load_label_map(const std::string &labels_map_file) = 0; + virtual void load_universal_labels(const std::string &universal_labels_file) = 0; + virtual void save_labels(const std::string &save_path, const size_t total_points) = 0; + virtual void save_label_map(const std::string &save_path) = 0; + virtual void save_universal_label(const std::string &save_path) = 0; + + private: + size_t _num_points; + + // mark Index as friend so it can access protected loads + template friend class Index; +}; + +} // namespace diskann \ No newline at end of file diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h new file mode 100644 index 000000000..405235f4c --- /dev/null +++ b/include/in_mem_filter_store.h @@ -0,0 +1,83 @@ +#pragma once +#include + +namespace diskann +{ + +// This class is responsible for filter actions in index, and should not be used outside. +template class InMemFilterStore : public AbstractFilterStore +{ + public: + InMemFilterStore(const size_t num_points); + ~InMemFilterStore() = default; + + // needs some internal lock + bool detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels, + const FilterMatchStrategy filter_match_strategy) override; + + const std::vector &get_labels_by_location(const location_t point_id) override; + // const label_type get_universal_label + + // Dynamic Index + void set_labels_to_location(const location_t location, const std::vector &labels); + void swap_labels(const location_t location_first, const location_t location_second) override; + const tsl::robin_set &get_all_label_set() override; + void add_to_label_set(const label_type &label) override; + // Throws: out of range exception + void add_label_to_location(const location_t point_id, const label_type label) override; + // returns internal mapping for given raw_label + label_type get_numeric_label(const std::string &raw_label) override; + + // takes raw universal labels and map them internally. + void set_universal_label(const std::string &raw_universal_labels) override; + std::pair get_universal_label() override; + + // ideally takes raw label file and then genrate internal mapping file and keep the info of mapping + size_t populate_labels(const std::string &raw_labels_file, const std::string &raw_universal_label) override; + + // save labels, labels_map and universal_label to files + void save(const std::string &save_path, const size_t total_points) override; + + // load labels, labels_map and universal_label to filter store variables & returns total number of points + size_t load(const std::string &load_path) override; + + // For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of + // raw labels to compute GT correctly. + void save_raw_labels(const std::string &save_path, const size_t total_points) override; + + // The function is static so it remains the source of truth across the code. Returns label map + DISKANN_DLLEXPORT static std::unordered_map convert_label_to_numeric( + const std::string &inFileName, const std::string &outFileName, const std::string &mapFileName, + const std::string &raw_universal_labels); + + protected: + // This is for internal use and only loads already parsed file, used by index in during load(). + // populates _loaction_to labels and _labels from given label file + size_t load_labels(const std::string &labels_file) override; + void load_label_map(const std::string &labels_map_file) override; + void load_universal_labels(const std::string &universal_labels_file) override; + void save_labels(const std::string &save_path, const size_t total_points) override; + void save_label_map(const std::string &save_path) override; + void save_universal_label(const std::string &save_path) override; + + private: + size_t _num_points; + std::vector> _location_to_labels; + tsl::robin_set _labels; + std::unordered_map _label_map; + + // universal label + bool _has_universal_label = false; + label_type _universal_label; + + // no need of storing raw universal label ? + // 1. _use_universal_label can be used to identify if universal label present or not + // 2. from _label_map and _mapped_universal_label, we can know what is raw universal label. Hence seems duplicate + // std::string _raw_universal_label; + + bool detect_common_filters_by_set_intersection(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels); +}; + +} // namespace diskann \ No newline at end of file diff --git a/include/index.h b/include/index.h index 199171020..ab5545ef8 100644 --- a/include/index.h +++ b/include/index.h @@ -20,6 +20,7 @@ #include "scratch.h" #include "in_mem_data_store.h" #include "in_mem_graph_store.h" +#include "in_mem_filter_store.h" #include "abstract_index.h" #include "quantized_distance.h" @@ -105,8 +106,8 @@ template clas IndexFilterParams &filter_params); // Filtered Support - DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file, - const size_t num_points_to_load, + DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const size_t num_points_to_load, + const IndexFilterParams filter_params, const std::vector &tags = std::vector()); DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); @@ -340,6 +341,9 @@ template clas // Graph related data structures std::unique_ptr _graph_store; + // Filter related data structure + std::unique_ptr> _filter_store; + char *_opt_graph = nullptr; // Dimensions diff --git a/include/index_build_params.h b/include/index_build_params.h index 0233fcec4..27d5180c8 100644 --- a/include/index_build_params.h +++ b/include/index_build_params.h @@ -3,69 +3,4 @@ namespace diskann { -struct IndexFilterParams -{ - public: - std::string save_path_prefix; - std::string label_file; - std::string tags_file; - std::string universal_label; - uint32_t filter_threshold = 0; - - private: - IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file, - const std::string &universal_label, uint32_t filter_threshold) - : save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label), - filter_threshold(filter_threshold) - { - } - - friend class IndexFilterParamsBuilder; -}; -class IndexFilterParamsBuilder -{ - public: - IndexFilterParamsBuilder() = default; - - IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) - { - if (save_path_prefix.empty() || save_path_prefix == "") - throw ANNException("Error: save_path_prefix can't be empty", -1); - this->_save_path_prefix = save_path_prefix; - return *this; - } - - IndexFilterParamsBuilder &with_label_file(const std::string &label_file) - { - this->_label_file = label_file; - return *this; - } - - IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label) - { - this->_universal_label = univeral_label; - return *this; - } - - IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold) - { - this->_filter_threshold = filter_threshold; - return *this; - } - - IndexFilterParams build() - { - return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold); - } - - IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete; - IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete; - - private: - std::string _save_path_prefix; - std::string _label_file; - std::string _tags_file; - std::string _universal_label; - uint32_t _filter_threshold = 0; -}; } // namespace diskann diff --git a/include/parameters.h b/include/parameters.h index 2bba9aeca..73cd88aa0 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -117,4 +117,57 @@ class IndexWriteParametersBuilder uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE}; }; +struct IndexFilterParams +{ + public: + std::string save_path_prefix; + std::string label_file; + std::string universal_label; + + private: + IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file, + const std::string &universal_label) + : save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label) + { + } + + friend class IndexFilterParamsBuilder; +}; +class IndexFilterParamsBuilder +{ + public: + IndexFilterParamsBuilder() = default; + + IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) + { + this->_save_path_prefix = save_path_prefix; + return *this; + } + + IndexFilterParamsBuilder &with_label_file(const std::string &label_file) + { + this->_label_file = label_file; + return *this; + } + + IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label) + { + this->_universal_label = univeral_label; + return *this; + } + + IndexFilterParams build() + { + return IndexFilterParams(_save_path_prefix, _label_file, _universal_label); + } + + IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete; + IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete; + + private: + std::string _save_path_prefix; + std::string _label_file; + std::string _universal_label; +}; + } // namespace diskann diff --git a/python/src/builder.cpp b/python/src/builder.cpp index e02a86d6c..dfc15e33b 100644 --- a/python/src/builder.cpp +++ b/python/src/builder.cpp @@ -98,7 +98,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_ { auto labels_file = prepare_filtered_label_map(index, index_output_path, filter_labels_file, universal_label); - index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags); + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(labels_file) + .with_save_path_prefix(index_output_path) + .build(); + index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params, tags); } } else @@ -111,7 +116,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_ { auto labels_file = prepare_filtered_label_map(index, index_output_path, filter_labels_file, universal_label); - index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num); + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(labels_file) + .with_save_path_prefix(index_output_path) + .build(); + index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params); } } diff --git a/src/abstract_filter_store.cpp b/src/abstract_filter_store.cpp new file mode 100644 index 000000000..c0105e597 --- /dev/null +++ b/src/abstract_filter_store.cpp @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "abstract_filter_store.h" + +namespace diskann +{ + +template +AbstractFilterStore::AbstractFilterStore(const size_t num_points) : _num_points(num_points) +{ +} + +template DISKANN_DLLEXPORT class AbstractFilterStore; +template DISKANN_DLLEXPORT class AbstractFilterStore; +} // namespace diskann \ No newline at end of file diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 624eecc65..8ebd36874 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -662,7 +662,12 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr LabelT unv_label_as_num = 0; _index.set_universal_label(unv_label_as_num); } - _index.build_filtered_index(base_file.c_str(), label_file, base_num); + diskann::IndexFilterParams filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(label_file) + .with_save_path_prefix(mem_index_path) + .build(); + _index.build_filtered_index(base_file.c_str(), base_num, filter_params); } _index.save(mem_index_path.c_str()); @@ -736,7 +741,12 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr LabelT unv_label_as_num = 0; _index.set_universal_label(unv_label_as_num); } - _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts); + diskann::IndexFilterParams filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(shard_labels_file) + .with_save_path_prefix(shard_index_file) + .build(); + _index.build_filtered_index(shard_base_file.c_str(), shard_base_pts, filter_params); } _index.save(shard_index_file.c_str()); // copy universal label file from first shard to the final destination diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 096d1b76e..a46003464 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -4,7 +4,7 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp ../in_mem_filter_store.cpp ../abstract_filter_store.cpp) set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp new file mode 100644 index 000000000..d02ef680e --- /dev/null +++ b/src/in_mem_filter_store.cpp @@ -0,0 +1,450 @@ +#include + +namespace diskann +{ + +template +InMemFilterStore::InMemFilterStore(const size_t num_points) : AbstractFilterStore(num_points) +{ + _location_to_labels.resize(num_points); +} + +template +bool InMemFilterStore::detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels, + const FilterMatchStrategy filter_match_strategy) +{ + switch (filter_match_strategy) + { + case FilterMatchStrategy::SET_INTERSECTION: + return this->detect_common_filters_by_set_intersection(point_id, search_invocation, incoming_labels); + default: + throw diskann::ANNException("Error: the provided filter match strategy is not supported.", -1); + } +} + +template +const std::vector &InMemFilterStore::get_labels_by_location(const location_t point_id) +{ + return _location_to_labels[point_id]; +} + +template +void InMemFilterStore::set_labels_to_location(const location_t location, + const std::vector &label_str) +{ + std::vector labels; + for (int i = 0; i < label_str.size(); i++) + { + labels.push_back(this->get_numeric_label(label_str[i])); + } + _location_to_labels[location] = labels; +} + +template +void InMemFilterStore::swap_labels(const location_t location_first, const location_t location_second) +{ + _location_to_labels[location_first].swap(_location_to_labels[location_second]); +} + +template const tsl::robin_set &InMemFilterStore::get_all_label_set() +{ + return _labels; +} + +template void InMemFilterStore::add_to_label_set(const label_type &label) +{ + _labels.insert(label); +} + +template +void InMemFilterStore::add_label_to_location(const location_t point_id, const label_type label) +{ + _location_to_labels[point_id].emplace_back(label); +} + +template +void InMemFilterStore::set_universal_label(const std::string &raw_universal_label) +{ + if (raw_universal_label.empty()) + { + std::cout << "Warning: empty universal label passed" << std::endl; + } + else + { + _has_universal_label = true; + _universal_label = _label_map[raw_universal_label]; + } +} + +template std::pair InMemFilterStore::get_universal_label() +{ + std::pair universal_label; + universal_label.second = _universal_label; + if (_has_universal_label) + { + universal_label.first = false; + } + else + { + universal_label.second = false; + } + return universal_label; +} + +// ideally takes raw label file and then genrate internal mapping and keep the info of mapping +template +size_t InMemFilterStore::populate_labels(const std::string &raw_labels_file, + const std::string &raw_universal_label) +{ + std::string raw_label_file_path = + std::string(raw_labels_file).erase(raw_labels_file.size() - 4); // remove .txt from end + // generate a map file + std::string labels_file_to_use = + raw_label_file_path + "_label_numeric.txt"; // will not be used after parse, can be safely deleted. + std::string mem_labels_int_map_file = raw_label_file_path + "_labels_map.txt"; + _label_map = InMemFilterStore::convert_label_to_numeric(raw_labels_file, labels_file_to_use, + mem_labels_int_map_file, raw_universal_label); + return load_labels(labels_file_to_use); +} + +template void InMemFilterStore::load_label_map(const std::string &labels_map_file) +{ + if (file_exists(labels_map_file)) + { + std::ifstream map_reader(labels_map_file); + std::string line, token; + label_type token_as_num; + std::string label_str; + while (std::getline(map_reader, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + label_str = token; + getline(iss, token, '\t'); + token_as_num = (label_type)std::stoul(token); + _label_map[label_str] = token_as_num; + } + } + else + { + // TODO: throw exception from here and also make sure filtered_index is set appropriately for both build and + // search of index. + diskann::cout << "Warning: Can't load label map file please make sure it was generate, either by " + "filter_store->populate_labels() " + "then index->save() or convert_label_to_numeric() method in case of dynamic index" + << std::endl; + } +} + +template +void InMemFilterStore::load_universal_labels(const std::string &universal_labels_file) +{ + if (file_exists(universal_labels_file)) + { + std::ifstream universal_label_reader(universal_labels_file); + std::string line; + while (std::getline(universal_label_reader, line)) + { + std::istringstream iss(line); + label_type universal_label; + if (!(iss >> universal_label)) + { + throw std::runtime_error("ERROR: Invalid universal label " + line); + } + _universal_label = universal_label; + _has_universal_label = true; + } + universal_label_reader.close(); + } +} + +// load labels, labels_map and universal_label to filter store variables & returns total number of points +template size_t InMemFilterStore::load(const std::string &load_path) +{ + const std::string labels_file = load_path + "_labels.txt"; + const std::string labels_map_file = load_path + "_labels_map.txt"; + const std::string universal_label_file = load_path + "_universal_label.txt"; + load_label_map(labels_map_file); + load_universal_labels(universal_label_file); + return load_labels(labels_file); +} + +template +void InMemFilterStore::save(const std::string &save_path, const size_t total_points) +{ + const std::string label_path = save_path + "_labels.txt"; + const std::string universal_label_path = save_path + "_universal_label.txt"; + const std::string label_map_path = save_path + "_labels_map.txt"; + save_label_map(label_map_path); + save_universal_label(universal_label_path); + save_labels(label_path, total_points); +} + +template +void InMemFilterStore::save_labels(const std::string &save_path, const size_t total_points) +{ + + if (_location_to_labels.size() > 0) + { + std::ofstream label_writer(save_path); + assert(label_writer.is_open()); + for (uint32_t i = 0; i < total_points; i++) + { + for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++) + { + label_writer << _location_to_labels[i][j] << ","; + } + if (_location_to_labels[i].size() != 0) + label_writer << _location_to_labels[i][_location_to_labels[i].size() - 1]; + label_writer << std::endl; + } + label_writer.close(); + } +} + +template +void InMemFilterStore::save_raw_labels(const std::string &save_path, const size_t total_points) +{ + if (_label_map.empty()) + { + diskann::cout << "Warning: not saving raw labels as label map is empty" << std::endl; + return; + } + std::unordered_map mapped_to_raw_labels; + // invert label map + for (const auto &[key, value] : _label_map) + { + mapped_to_raw_labels.insert({value, key}); + } + + // write updated labels + std::ofstream raw_label_writer(save_path); + assert(raw_label_writer.is_open()); + for (uint32_t i = 0; i < total_points; i++) + { + for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++) + { + raw_label_writer << mapped_to_raw_labels[_location_to_labels[i][j]] << ","; + } + if (_location_to_labels[i].size() != 0) + raw_label_writer << mapped_to_raw_labels[_location_to_labels[i][_location_to_labels[i].size() - 1]]; + + raw_label_writer << std::endl; + } + raw_label_writer.close(); +} + +template void InMemFilterStore::save_universal_label(const std::string &save_path) +{ + if (_has_universal_label) + { + std::ofstream universal_label_writer(save_path); + assert(universal_label_writer.is_open()); + // universal_label_writer << _universal_label << std::endl; + + universal_label_writer << _universal_label << std::endl; + universal_label_writer.close(); + } +} + +template void InMemFilterStore::save_label_map(const std::string &save_path) +{ + if (_label_map.empty()) + { + diskann::cout << "Warning: not saving label map as it is empty." << std::endl; + return; + } + std::ofstream map_writer(save_path); + for (auto mp : _label_map) + { + map_writer << mp.first << "\t" << mp.second << std::endl; + } + map_writer.close(); +} + +template label_type InMemFilterStore::get_numeric_label(const std::string &raw_label) +{ + if (_label_map.empty()) + { + throw diskann::ANNException("Error: Label map is empty, please load the map before hand", -1); + } + if (_label_map.find(raw_label) != _label_map.end()) + { + return _label_map[raw_label]; + } + // why is this here + if (_has_universal_label) + { + // Not sure why this is here, but when we start getting more labels chnage this + return _universal_label; + } + std::stringstream stream; + stream << "Unable to find label in the Label Map"; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); +} + +template size_t InMemFilterStore::load_labels(const std::string &label_file) +{ + // Format of Label txt file: filters with comma separators + // Format of Label txt file: filters with comma separators + std::ifstream infile(label_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + uint32_t line_cnt = 0; + + // Find total number of points in the labels file to reserve _locations_to_labels + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + break; + cur_pos = next_pos + 1; + line_cnt++; + } + cur_pos = 0; + next_pos = 0; + _location_to_labels.resize(line_cnt, std::vector()); + line_cnt = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + std::vector lbls(0); + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = buffer.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + label_type token_as_num = (label_type)std::stoul(label_str); + lbls.push_back(token_as_num); + _labels.insert(token_as_num); + // move to next label + lbl_pos = next_lbl_pos + 1; + } + cur_pos = next_pos + 1; + _location_to_labels[line_cnt] = lbls; + line_cnt++; + } + diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl; + return (size_t)line_cnt; +} + +template +std::unordered_map InMemFilterStore::convert_label_to_numeric( + const std::string &inFileName, const std::string &outFileName, const std::string &mapFileName, + const std::string &raw_universal_label) +{ + std::unordered_map string_int_map; + std::ofstream label_writer(outFileName); + std::ifstream label_reader(inFileName); + std::string line, token; + if (raw_universal_label != "") + string_int_map[raw_universal_label] = 0; // if universal label is provided map it to 0 always + + while (std::getline(label_reader, line)) + { + std::istringstream new_iss(line); + std::vector lbls; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + if (string_int_map.find(token) == string_int_map.end()) + { + uint32_t nextId = (uint32_t)string_int_map.size(); + string_int_map[token] = nextId; + } + lbls.push_back(string_int_map[token]); + } + if (lbls.size() <= 0) + { + std::cout << "No label found"; + exit(-1); + } + for (size_t j = 0; j < lbls.size(); j++) + { + if (j != lbls.size() - 1) + label_writer << lbls[j] << ","; + else + label_writer << lbls[j] << std::endl; + } + } + label_writer.close(); + + std::ofstream map_writer(mapFileName); + for (auto mp : string_int_map) + { + map_writer << mp.first << "\t" << mp.second << std::endl; + } + map_writer.close(); + + return string_int_map; +} + +template +bool InMemFilterStore::detect_common_filters_by_set_intersection( + uint32_t point_id, bool search_invocation, const std::vector &incoming_labels) +{ + auto &curr_node_labels = _location_to_labels[point_id]; + std::set common_filters; + std::set_intersection(incoming_labels.begin(), incoming_labels.end(), curr_node_labels.begin(), + curr_node_labels.end(), std::inserter(common_filters, common_filters.end())); + if (common_filters.size() > 0) + { + // This is to reduce the repetitive calls. If common_filters size is > 0 , + // we dont need to check further for universal label + return true; + } + if (_has_universal_label) + { + if (!search_invocation) + { + if (std::find(incoming_labels.begin(), incoming_labels.end(), _universal_label) != incoming_labels.end() || + std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.insert(_universal_label); + } + else + { + if (std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.insert(_universal_label); + } + } + return (common_filters.size() > 0); +} + +template DISKANN_DLLEXPORT class InMemFilterStore; +template DISKANN_DLLEXPORT class InMemFilterStore; + +} // namespace diskann \ No newline at end of file diff --git a/src/index.cpp b/src/index.cpp index d906600d1..aef1132b3 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -74,6 +74,10 @@ Index::Index(const IndexConfig &index_config, std::shared_ptr>(total_internal_points); + }*/ _locks = std::vector(total_internal_points); if (_enable_tags) @@ -151,6 +155,10 @@ Index::Index(Metric m, const size_t dim, const size_t max_point { _pq_data_store = _data_store; } + /*if (filtered_index) + { + _filter_store = std::make_unique>(max_points + num_frozen_pts); + }*/ } template Index::~Index() @@ -1749,7 +1757,12 @@ void Index::build(const std::string &data_file, const size_t nu LabelT unv_label_as_num = 0; this->set_universal_label(unv_label_as_num); } - this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load); + diskann::IndexFilterParams filter_params_build = diskann::IndexFilterParamsBuilder() + .with_universal_label(filter_params.universal_label) + .with_label_file(labels_file_to_use) + .with_save_path_prefix(filter_params.save_path_prefix) + .build(); + this->build_filtered_index(data_file.c_str(), points_to_load, filter_params_build); } std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; std::cout << "Indexing time: " << diff.count() << "\n"; @@ -1853,14 +1866,14 @@ void Index::set_universal_label(const LabelT &label) } template -void Index::build_filtered_index(const char *filename, const std::string &label_file, - const size_t num_points_to_load, const std::vector &tags) +void Index::build_filtered_index(const char *filename, const size_t num_points_to_load, + const IndexFilterParams filter_params, const std::vector &tags) { _filtered_index = true; _label_to_start_id.clear(); size_t num_points_labels = 0; - parse_label_file(label_file, + parse_label_file(filter_params.label_file, num_points_labels); // determines medoid for each label and identifies // the points to label mapping