Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FilterStore: unifying filter specific logic #508

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions include/abstract_filter_store.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once
#include "common_includes.h"
#include "utils.h"
#include <any>

namespace diskann
{

enum class FilterMatchStrategy
{
SET_INTERSECTION
};
// This class is responsible for filter actions in index, and should not be used outside.
template <typename label_type> class AbstractFilterStore
{
public:
AbstractFilterStore(const size_t num_points);
virtual ~AbstractFilterStore() = default;

// needs some internal lock + abstract implementation
virtual bool detect_common_filters(uint32_t point_id, bool search_invocation,
const std::vector<label_type> &incoming_labels,
const FilterMatchStrategy strategy = FilterMatchStrategy::SET_INTERSECTION) = 0;

DISKANN_DLLEXPORT virtual const std::vector<label_type> &get_labels_by_location(const location_t point_id) = 0;
virtual void set_labels_to_location(const location_t location, const std::vector<std::string> &labels) = 0;
virtual void swap_labels(const location_t location_first, const location_t location_second) = 0;

virtual const tsl::robin_set<label_type> &get_all_label_set() = 0;
virtual void add_to_label_set(const label_type &label) = 0;
// Throws: out of range exception
virtual void add_label_to_location(const location_t point_id, const label_type label) = 0;
// returns internal mapping for given raw_label
DISKANN_DLLEXPORT virtual label_type get_numeric_label(const std::string &raw_label) = 0;

// TODO: in future we may accept a set or vector of universal labels
// DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0;
virtual void set_universal_label(const std::string &universal_labels) = 0;
virtual std::pair<bool, label_type> get_universal_label() = 0;

// takes raw label file and then genrate internal mapping file and keep the info of mapping
DISKANN_DLLEXPORT virtual size_t populate_labels(const std::string &raw_labels_file,
const std::string &raw_universal_label) = 0;

// save labels, labels_map and universal_label to files
virtual void save(const std::string &save_path, const size_t total_points) = 0;

// load labels, labels_map and universal_label to filter store variables & returns total number of points
virtual size_t load(const std::string &load_path) = 0;

// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of
// raw labels to compute GT correctly.
DISKANN_DLLEXPORT virtual void save_raw_labels(const std::string &save_path, const size_t total_points) = 0;

protected:
// This is for internal use and only loads already parsed file
virtual size_t load_labels(const std::string &labels_file) = 0;
virtual void load_label_map(const std::string &labels_map_file) = 0;
virtual void load_universal_labels(const std::string &universal_labels_file) = 0;
virtual void save_labels(const std::string &save_path, const size_t total_points) = 0;
virtual void save_label_map(const std::string &save_path) = 0;
virtual void save_universal_label(const std::string &save_path) = 0;

private:
size_t _num_points;

// mark Index as friend so it can access protected loads
template <typename T, typename TagT, typename LabelT> friend class Index;
};

} // namespace diskann
83 changes: 83 additions & 0 deletions include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#pragma once
#include <abstract_filter_store.h>

namespace diskann
{

// This class is responsible for filter actions in index, and should not be used outside.
template <typename label_type> class InMemFilterStore : public AbstractFilterStore<label_type>
{
public:
InMemFilterStore(const size_t num_points);
~InMemFilterStore() = default;

// needs some internal lock
bool detect_common_filters(uint32_t point_id, bool search_invocation,
const std::vector<label_type> &incoming_labels,
const FilterMatchStrategy filter_match_strategy) override;

const std::vector<label_type> &get_labels_by_location(const location_t point_id) override;
// const label_type get_universal_label

// Dynamic Index
void set_labels_to_location(const location_t location, const std::vector<std::string> &labels);
void swap_labels(const location_t location_first, const location_t location_second) override;
const tsl::robin_set<label_type> &get_all_label_set() override;
void add_to_label_set(const label_type &label) override;
// Throws: out of range exception
void add_label_to_location(const location_t point_id, const label_type label) override;
// returns internal mapping for given raw_label
label_type get_numeric_label(const std::string &raw_label) override;

// takes raw universal labels and map them internally.
void set_universal_label(const std::string &raw_universal_labels) override;
std::pair<bool, label_type> get_universal_label() override;

// ideally takes raw label file and then genrate internal mapping file and keep the info of mapping
size_t populate_labels(const std::string &raw_labels_file, const std::string &raw_universal_label) override;

// save labels, labels_map and universal_label to files
void save(const std::string &save_path, const size_t total_points) override;

// load labels, labels_map and universal_label to filter store variables & returns total number of points
size_t load(const std::string &load_path) override;

// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of
// raw labels to compute GT correctly.
void save_raw_labels(const std::string &save_path, const size_t total_points) override;

// The function is static so it remains the source of truth across the code. Returns label map
DISKANN_DLLEXPORT static std::unordered_map<std::string, label_type> convert_label_to_numeric(
const std::string &inFileName, const std::string &outFileName, const std::string &mapFileName,
const std::string &raw_universal_labels);

protected:
// This is for internal use and only loads already parsed file, used by index in during load().
// populates _loaction_to labels and _labels from given label file
size_t load_labels(const std::string &labels_file) override;
void load_label_map(const std::string &labels_map_file) override;
void load_universal_labels(const std::string &universal_labels_file) override;
void save_labels(const std::string &save_path, const size_t total_points) override;
void save_label_map(const std::string &save_path) override;
void save_universal_label(const std::string &save_path) override;

private:
size_t _num_points;
std::vector<std::vector<label_type>> _location_to_labels;
tsl::robin_set<label_type> _labels;
std::unordered_map<std::string, label_type> _label_map;

// universal label
bool _has_universal_label = false;
label_type _universal_label;

// no need of storing raw universal label ?
// 1. _use_universal_label can be used to identify if universal label present or not
// 2. from _label_map and _mapped_universal_label, we can know what is raw universal label. Hence seems duplicate
// std::string _raw_universal_label;

bool detect_common_filters_by_set_intersection(uint32_t point_id, bool search_invocation,
const std::vector<label_type> &incoming_labels);
};

} // namespace diskann
8 changes: 6 additions & 2 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "scratch.h"
#include "in_mem_data_store.h"
#include "in_mem_graph_store.h"
#include "in_mem_filter_store.h"
#include "abstract_index.h"

#include "quantized_distance.h"
Expand Down Expand Up @@ -105,8 +106,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
IndexFilterParams &filter_params);

// Filtered Support
DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file,
const size_t num_points_to_load,
DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const size_t num_points_to_load,
const IndexFilterParams filter_params,
const std::vector<TagT> &tags = std::vector<TagT>());

DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
Expand Down Expand Up @@ -340,6 +341,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// Graph related data structures
std::unique_ptr<AbstractGraphStore> _graph_store;

// Filter related data structure
std::unique_ptr<AbstractFilterStore<LabelT>> _filter_store;

char *_opt_graph = nullptr;

// Dimensions
Expand Down
65 changes: 0 additions & 65 deletions include/index_build_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,4 @@

namespace diskann
{
struct IndexFilterParams
{
public:
std::string save_path_prefix;
std::string label_file;
std::string tags_file;
std::string universal_label;
uint32_t filter_threshold = 0;

private:
IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file,
const std::string &universal_label, uint32_t filter_threshold)
: save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label),
filter_threshold(filter_threshold)
{
}

friend class IndexFilterParamsBuilder;
};
class IndexFilterParamsBuilder
{
public:
IndexFilterParamsBuilder() = default;

IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix)
{
if (save_path_prefix.empty() || save_path_prefix == "")
throw ANNException("Error: save_path_prefix can't be empty", -1);
this->_save_path_prefix = save_path_prefix;
return *this;
}

IndexFilterParamsBuilder &with_label_file(const std::string &label_file)
{
this->_label_file = label_file;
return *this;
}

IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label)
{
this->_universal_label = univeral_label;
return *this;
}

IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold)
{
this->_filter_threshold = filter_threshold;
return *this;
}

IndexFilterParams build()
{
return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold);
}

IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete;
IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete;

private:
std::string _save_path_prefix;
std::string _label_file;
std::string _tags_file;
std::string _universal_label;
uint32_t _filter_threshold = 0;
};
} // namespace diskann
53 changes: 53 additions & 0 deletions include/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,57 @@ class IndexWriteParametersBuilder
uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE};
};

struct IndexFilterParams
{
public:
std::string save_path_prefix;
std::string label_file;
std::string universal_label;

private:
IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file,
const std::string &universal_label)
: save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label)
{
}

friend class IndexFilterParamsBuilder;
};
class IndexFilterParamsBuilder
{
public:
IndexFilterParamsBuilder() = default;

IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix)
{
this->_save_path_prefix = save_path_prefix;
return *this;
}

IndexFilterParamsBuilder &with_label_file(const std::string &label_file)
{
this->_label_file = label_file;
return *this;
}

IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label)
{
this->_universal_label = univeral_label;
return *this;
}

IndexFilterParams build()
{
return IndexFilterParams(_save_path_prefix, _label_file, _universal_label);
}

IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete;
IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete;

private:
std::string _save_path_prefix;
std::string _label_file;
std::string _universal_label;
};

} // namespace diskann
14 changes: 12 additions & 2 deletions python/src/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_
{
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags);
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(labels_file)
.with_save_path_prefix(index_output_path)
.build();
index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params, tags);
}
}
else
Expand All @@ -111,7 +116,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_
{
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num);
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(labels_file)
.with_save_path_prefix(index_output_path)
.build();
index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params);
}
}

Expand Down
16 changes: 16 additions & 0 deletions src/abstract_filter_store.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "abstract_filter_store.h"

namespace diskann
{

template <typename label_type>
AbstractFilterStore<label_type>::AbstractFilterStore(const size_t num_points) : _num_points(num_points)
{
}

template DISKANN_DLLEXPORT class AbstractFilterStore<uint16_t>;
template DISKANN_DLLEXPORT class AbstractFilterStore<uint32_t>;
} // namespace diskann
14 changes: 12 additions & 2 deletions src/disk_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,12 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
LabelT unv_label_as_num = 0;
_index.set_universal_label(unv_label_as_num);
}
_index.build_filtered_index(base_file.c_str(), label_file, base_num);
diskann::IndexFilterParams filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(label_file)
.with_save_path_prefix(mem_index_path)
.build();
_index.build_filtered_index(base_file.c_str(), base_num, filter_params);
}
_index.save(mem_index_path.c_str());

Expand Down Expand Up @@ -736,7 +741,12 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
LabelT unv_label_as_num = 0;
_index.set_universal_label(unv_label_as_num);
}
_index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts);
diskann::IndexFilterParams filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(shard_labels_file)
.with_save_path_prefix(shard_index_file)
.build();
_index.build_filtered_index(shard_base_file.c_str(), shard_base_pts, filter_params);
}
_index.save(shard_index_file.c_str());
// copy universal label file from first shard to the final destination
Expand Down
2 changes: 1 addition & 1 deletion src/dll/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp
../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp
../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp
../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp)
../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp ../in_mem_filter_store.cpp ../abstract_filter_store.cpp)

set(TARGET_DIR "$<$<CONFIG:Debug>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$<CONFIG:Release>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>")

Expand Down
Loading
Loading