Skip to content

Commit

Permalink
Version almost on par with DLVS multi-filter
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr authored and rakri committed Jan 2, 2025
1 parent a78592f commit 392c49f
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 95 deletions.
127 changes: 103 additions & 24 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,101 @@
#endif

#define WARMUP false
#define DISKANN_DEBUG_INDIVIDUAL_RESULTS

namespace po = boost::program_options;

#ifdef DISKANN_DEBUG_INDIVIDUAL_RESULTS
void dump_individual_results(uint64_t test_id, uint64_t query_num, uint32_t *gt_ids, float *gt_dists, uint64_t gt_dim,
const std::vector<uint32_t> &query_result_ids,
const std::vector<float> &query_result_dists, uint64_t recall_at,
const std::string &result_output_prefix)
{
uint32_t cumulative_dist_matches = 0;
uint32_t cumulative_id_matches = 0;
std::stringstream results_stream;
std::stringstream per_query_stats_stream;

per_query_stats_stream << "query_id\tid_matches\tdist_matches\ttotal_matches\trecall" << std::endl;
for (int qid = 0; qid < query_num; qid++)
{
results_stream << qid << "\t";
uint32_t per_query_dist_matches = 0;
uint32_t per_query_id_matches = 0;

for (uint64_t i = 0; i < recall_at; i++)
{
auto rindex = qid * recall_at + i;
results_stream << "(" << query_result_ids[rindex] << "," << query_result_dists[rindex] << ",";

bool id_match = false;
bool dist_match = false;
for (uint64_t j = 0; j < recall_at; j++)
{
auto gindex = qid * gt_dim + j;
if (query_result_ids[rindex] == gt_ids[gindex])
{
per_query_id_matches++;
id_match = true;
break;
}
else if (query_result_dists[rindex] / gt_dists[gindex] <= 1.0f)
{
per_query_dist_matches++;
dist_match = true;
break;
}
}
std::string code = "X";
if (id_match)
{
code = "I";
}
else if (dist_match)
{
code = "D";
}
results_stream << code << "),";
}

results_stream << std::endl;

cumulative_id_matches += per_query_id_matches;
cumulative_dist_matches += per_query_dist_matches;
per_query_stats_stream << qid << "\t" << per_query_id_matches << "\t" << per_query_dist_matches << "\t"
<< per_query_id_matches + per_query_dist_matches << "\t"
<< (per_query_id_matches + per_query_dist_matches) * 1.0f / recall_at << std::endl;
}
{

std::string results_file = result_output_prefix + "_L" + std::to_string(test_id) + "_results.tsv";
std::ofstream out(results_file);
out << results_stream.str() << std::endl;
}
{
std::string per_query_stats_file = result_output_prefix + "_L" + std::to_string(test_id) + "_query_stats.tsv";
std::ofstream out(per_query_stats_file);
out << per_query_stats_stream.str() << std::endl;
}
}

void write_gt_to_tsv(const std::string &cur_result_path, uint64_t query_num, uint32_t *gt_ids, float *gt_dists,
uint64_t gt_dim)
{
std::ofstream gt_out(cur_result_path + "_gt.tsv");
for (int i = 0; i < query_num; i++)
{
gt_out << i << "\t";
for (int j = 0; j < gt_dim; j++)
{
gt_out << "(" << gt_ids[i * gt_dim + j] << "," << gt_dists[i * gt_dim + j] << "),";
}
gt_out << std::endl;
}
}
#endif


void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
diskann::cout << std::setw(20) << category << ": " << std::flush;
Expand All @@ -54,7 +146,7 @@ void parse_labels_of_query(const std::string &filters_for_query,
std::vector<LabelT> &label_ids_for_query)
{
std::vector<std::string> label_strs_for_query;
diskann::split_string(filters_for_query, MULTIPLE_LABEL_SEPARATOR, label_strs_for_query);
diskann::split_string(filters_for_query, FILTER_OR_SEPARATOR, label_strs_for_query);
for (auto &label_str_for_query : label_strs_for_query)
{
label_ids_for_query.push_back(pFlashIndex->get_converted_label(label_str_for_query));
Expand All @@ -70,7 +162,7 @@ void populate_label_ids(const std::vector<std::string> &filters_of_queries,
{
std::vector<LabelT> label_ids_of_query;
parse_labels_of_query(filters_of_queries[0], pFlashIndex, label_ids_of_query);
for (auto i = 0; i < query_count; i++)
for (uint32_t i = 0; i < query_count; i++)
{
label_ids_of_queries.push_back(label_ids_of_query);
}
Expand Down Expand Up @@ -318,6 +410,10 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
query_result_ids[test_id].data(), recall_at, recall_at);
best_recall = std::max(recall, best_recall);
}
#ifdef DISKANN_DEBUG_INDIVIDUAL_RESULTS
dump_individual_results(test_id, query_num, gt_ids, gt_dists, gt_dim, query_result_ids[test_id],
query_result_dists[test_id], recall_at, result_output_prefix);
#endif

diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
Expand All @@ -327,31 +423,14 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
diskann::cout << std::setw(16) << recall << std::endl ;
}
else
{
diskann::cout << std::endl;

//std::stringstream rslts_string;
//for (auto x = 0; x < query_num; x++)
//{
// rslts_string << "-----------------------------------------" << std::endl;
// rslts_string << "Query: " << x << std::endl;
// rslts_string << "GT: {";
// for (auto rx = 0; rx < recall_at; rx++)
// {
// rslts_string << "(" << gt_ids[x* gt_dim + rx] << "," << gt_dists[x * gt_dim + rx] << "), ";
// }
// rslts_string << "}" << std::endl;
// rslts_string << "Results: {";
// for (auto rx = 0; rx < recall_at; rx++)
// {
// rslts_string << "(" << query_result_ids[test_id][x * recall_at + rx] << ","
// << query_result_dists[test_id][x * recall_at + rx] << "), ";
// }
// rslts_string << "}" << std::endl;
// rslts_string << "-----------------------------------------" << std::endl;
//}
//diskann::cout << rslts_string.str() << std::endl;
}
delete[] stats;
}
#ifdef DISKANN_DEBUG_INDIVIDUAL_RESULTS
write_gt_to_tsv(result_output_prefix, query_num, gt_ids, gt_dists, gt_dim);
#endif

diskann::cout << "Done searching. Now saving results " << std::endl;
uint64_t test_id = 0;
Expand Down
5 changes: 0 additions & 5 deletions include/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,6 @@ class NeighborPriorityQueue
return _cur < _size;
}

void sort()
{
std::sort(_data.begin(), _data.begin() + _size);
}

size_t size() const
{
return _size;
Expand Down
32 changes: 21 additions & 11 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
#include "tsl/robin_set.h"

#define FULL_PRECISION_REORDER_MULTIPLIER 3
#define DEFAULT_VISITED_RESERVE_SIZE 4096
//default max filters per query is set to the same
//as what we expect Bing to provide. If this is overkill,
//it can be set by clients in the load() function
#define DEFAULT_MAX_FILTERS_PER_QUERY 4096

namespace diskann
{
Expand All @@ -30,24 +35,28 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT ~PQFlashIndex();

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix);
DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix,
uint32_t max_filters_per_query = DEFAULT_MAX_FILTERS_PER_QUERY);
#else
// load compressed data, and obtains the handle to the disk-resident index
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix,
uint32_t max_filters_per_query = DEFAULT_MAX_FILTERS_PER_QUERY);
#endif

DISKANN_DLLEXPORT void load_labels(const std::string& disk_index_filepath);
DISKANN_DLLEXPORT void load_label_medoid_map(
const std::string &labels_to_medoids_filepath, std::istream &medoid_stream);
DISKANN_DLLEXPORT void load_dummy_map(const std::string& dummy_map_filepath, std::istream &dummy_map_stream);
DISKANN_DLLEXPORT void load_labels(const std::string &disk_index_filepath);
DISKANN_DLLEXPORT void load_label_medoid_map(const std::string &labels_to_medoids_filepath,
std::istream &medoid_stream);
DISKANN_DLLEXPORT void load_dummy_map(const std::string &dummy_map_filepath, std::istream &dummy_map_stream);

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
const char *compressed_filepath);
const char *compressed_filepath,
uint32_t max_filters_per_query);
#else
DISKANN_DLLEXPORT int load_from_separate_paths(uint32_t num_threads, const char *index_filepath,
const char *pivots_filepath, const char *compressed_filepath);
const char *pivots_filepath, const char *compressed_filepath,
uint32_t max_filters_per_query);
#endif

DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);
Expand Down Expand Up @@ -116,7 +125,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = DEFAULT_VISITED_RESERVE_SIZE,
uint64_t max_filters_per_query = DEFAULT_MAX_FILTERS_PER_QUERY);

DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);

Expand Down Expand Up @@ -189,7 +199,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
// chunk_size = chunk size of each dimension chunk
// pq_tables = float* [[2^8 * [chunk_size]] * _n_chunks]
uint8_t *data = nullptr;
uint64_t _n_chunks;
uint64_t _n_chunks = 0;
FixedChunkPQTable _pq_table;

// distance comparator
Expand All @@ -207,7 +217,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
// we can optionally have multiple starting points
uint32_t *_medoids = nullptr;
// defaults to 1
size_t _num_medoids;
size_t _num_medoids = 1;
// by default, it is empty. If there are multiple
// centroids, we pick the medoid corresponding to the
// closest centroid as the starting point of search
Expand Down
4 changes: 2 additions & 2 deletions include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ template <typename T> class SSDQueryScratch : public AbstractScratch<T>
NeighborPriorityQueue retset;
std::vector<Neighbor> full_retset;

SSDQueryScratch(size_t aligned_dim, size_t visited_reserve);
SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_filters_per_query);
~SSDQueryScratch();

void reset();
Expand All @@ -162,7 +162,7 @@ template <typename T> class SSDThreadData
SSDQueryScratch<T> scratch;
IOContext ctx;

SSDThreadData(size_t aligned_dim, size_t visited_reserve);
SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_filters_per_query);
void clear();
};

Expand Down
2 changes: 1 addition & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ typedef int FileHandle;

#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
#define PBWIDTH 60
#define MULTIPLE_LABEL_SEPARATOR "|"
#define FILTER_OR_SEPARATOR "|"

inline bool file_exists_impl(const std::string &name, bool dirCheck = false)
{
Expand Down
Loading

0 comments on commit 392c49f

Please sign in to comment.