Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
REDMOND\ninchen committed May 29, 2024
1 parent cb455f9 commit d053e33
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 33 deletions.
2 changes: 1 addition & 1 deletion apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
.is_enable_tags(tags)
.is_concurrent_consolidate(false)
.is_pq_dist_build(use_pq_build)
.is_use_opq(use_pq_build)
.is_use_opq(use_opq)
.with_num_pq_chunks(pq_num_chunks)
.with_num_frozen_pts(num_frozen_pts)
.with_pq_codebook_path(codebook_file)
Expand Down
4 changes: 3 additions & 1 deletion include/fixed_chunk_pq_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ namespace diskann
float l2_distance(const float *query_vec, uint8_t *base_vec);
float inner_product(const float *query_vec, uint8_t *base_vec);
// assumes no rotation is involved
void inflate_vector(uint8_t *base_vec, float *out_vec);
template <typename InputType = uint8_t, typename OutputType = float>
void inflate_vector(InputType *base_vec, OutputType *out_vec) const;

void populate_chunk_inner_products(const float *query_vec, float *dist_vec);

float *tables = nullptr; // pq_tables = float array of size [256 * ndims]
Expand Down
1 change: 1 addition & 0 deletions include/pq_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>

Metric _distance_metric;
std::unique_ptr<Distance<data_t>> _distance_fn = nullptr;
std::string _pq_pivot_file_path;
std::unique_ptr<QuantizedDistance<data_t>> _pq_distance_fn = nullptr;
};
} // namespace diskann
2 changes: 1 addition & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ void convert_types(const InType *srcmat, OutType *destmat, size_t npts, size_t d
{
for (uint64_t j = 0; j < dim; j++)
{
destmat[i * dim + j] = (OutType)srcmat[i * dim + j];
destmat[i * dim + j] = static_cast<OutType>(srcmat[i * dim + j]);
}
}
}
Expand Down
12 changes: 9 additions & 3 deletions src/fixed_chunk_pq_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,15 @@ namespace diskann
}

// assumes no rotation is involved
void FixedChunkPQTable::inflate_vector(uint8_t *base_vec, float *out_vec)
template <typename InputType, typename OutputType>
void FixedChunkPQTable::inflate_vector(InputType *base_vec, OutputType *out_vec) const
{
for (size_t chunk = 0; chunk < n_chunks; chunk++)
{
for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++)
{
const float *centers_dim_vec = tables_tr + (256 * j);
out_vec[j] = centers_dim_vec[base_vec[chunk]] + centroid[j];
out_vec[j] = static_cast<OutputType> (centers_dim_vec[static_cast<uint8_t>(base_vec[chunk])] + centroid[j]);
}
}
}
Expand Down Expand Up @@ -263,4 +264,9 @@ namespace diskann
}
}
}
} // namespace diskann

template void FixedChunkPQTable::inflate_vector<uint8_t, float>(uint8_t *base_vec, float *out_vec) const;
template void FixedChunkPQTable::inflate_vector<uint8_t, uint8_t>(uint8_t *base_vec, uint8_t *out_vec) const;
template void FixedChunkPQTable::inflate_vector<int8_t, int8_t>(int8_t *base_vec, int8_t *out_vec) const;
template void FixedChunkPQTable::inflate_vector<float, float>(float *base_vec, float *out_vec) const;
} // namespace diskann
21 changes: 5 additions & 16 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L

if (!use_filter)
{
_data_store->get_vector(location, scratch->aligned_query());
_pq_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);
}
else
Expand All @@ -1041,7 +1041,7 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L
if (_dynamic_index)
tl.unlock();

_data_store->get_vector(location, scratch->aligned_query());
_pq_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true,
_location_to_labels[location], false);

Expand All @@ -1055,7 +1055,7 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L
// clear scratch for finding unfiltered candidates
scratch->clear();

_data_store->get_vector(location, scratch->aligned_query());
_pq_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);

for (auto unfiltered_neighbour : scratch->pool())
Expand Down Expand Up @@ -1622,11 +1622,6 @@ void Index<T, TagT, LabelT>::build(const T *data, const size_t num_points_to_loa
{
throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__);
}
if (_pq_dist)
{
throw ANNException("ERROR: DO not use this build interface with PQ distance", -1, __FUNCSIG__, __FILE__,
__LINE__);
}

std::unique_lock<std::shared_timed_mutex> ul(_update_lock);

Expand All @@ -1635,6 +1630,7 @@ void Index<T, TagT, LabelT>::build(const T *data, const size_t num_points_to_loa
_nd = num_points_to_load;

_data_store->populate_data(data, (location_t)num_points_to_load);
_pq_data_store->populate_data(data, (location_t)num_points_to_load);
}

build_with_data_populated(tags);
Expand Down Expand Up @@ -2236,14 +2232,7 @@ size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K

if (res_vectors.size() > 0)
{
if (_pq_dist)
{
_pq_data_store->get_vector(node.id, res_vectors[pos]);
}
else
{
_data_store->get_vector(node.id, res_vectors[pos]);
}
_pq_data_store->get_vector(node.id, res_vectors[pos]);
}

if (distances != nullptr)
Expand Down
21 changes: 10 additions & 11 deletions src/pq_data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ PQDataStore<data_t>::PQDataStore(size_t dim, location_t num_points, size_t num_p
std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn,
const std::string &codebook_path)
#endif
: AbstractDataStore<data_t>(num_points, num_pq_chunks), _num_chunks(num_pq_chunks)
: AbstractDataStore<data_t>(num_points, num_pq_chunks), _num_chunks(num_pq_chunks),
_pq_pivot_file_path(codebook_path)
{
if (num_pq_chunks > dim)
{
Expand Down Expand Up @@ -99,7 +100,10 @@ template <typename data_t> void PQDataStore<data_t>::populate_data(const std::st

double p_val = std::min(1.0, ((double)MAX_PQ_TRAINING_SET_SIZE / (double)file_num_points));

auto pivots_file = get_pivot_data_filename(filename, _use_opq, static_cast<uint32_t>(_num_chunks));
auto pivots_file = _pq_pivot_file_path.empty()
? get_pivot_data_filename(filename, _use_opq, static_cast<uint32_t>(_num_chunks))
: _pq_pivot_file_path;

auto compressed_file = get_quantized_vectors_filename(filename, _use_opq, static_cast<uint32_t>(_num_chunks));

generate_quantized_data<data_t>(filename, pivots_file, compressed_file, _distance_fn->get_metric(), p_val,
Expand Down Expand Up @@ -130,7 +134,8 @@ template <typename data_t> void PQDataStore<data_t>::get_vector(const location_t
// REFACTOR TODO: Should we inflate the compressed vector here?
if (i < this->capacity())
{
memcpy(dest, _quantized_data + i * _aligned_dim, this->_dim * sizeof(data_t));
const FixedChunkPQTable &pq_table = _pq_distance_fn->get_pq_table();
pq_table.inflate_vector<data_t, data_t>((data_t *)(_quantized_data + i * _aligned_dim), dest);
}
else
{
Expand All @@ -157,7 +162,6 @@ template <typename data_t> void PQDataStore<data_t>::set_vector(const location_t
uint64_t num_chunks = _num_chunks;

std::vector<float> vector_float(full_dimension);

diskann::convert_types<data_t, float>(vector, vector_float.data(), 1, full_dimension);
std::vector<uint8_t> compressed_vector(num_chunks * sizeof(data_t));
std::vector<data_t> compressed_vector_T(num_chunks);
Expand Down Expand Up @@ -249,17 +253,12 @@ void PQDataStore<data_t>::preprocess_query(const data_t *aligned_query, Abstract

template <typename data_t> float PQDataStore<data_t>::get_distance(const data_t *query, const location_t loc) const
{
// Probably should return PQ distance.
return _distance_fn->compare(query, reinterpret_cast<data_t *>(_quantized_data) + _aligned_dim * loc,
(uint32_t)_aligned_dim);
throw diskann::ANNException("get_distance(const data_t *query, const location_t loc) hasn't been implemented for PQDataStore", -1);
}

template <typename data_t> float PQDataStore<data_t>::get_distance(const location_t loc1, const location_t loc2) const
{
// Probably should return PQ distance.
return _distance_fn->compare(reinterpret_cast<data_t *>(_quantized_data) + loc1 * _aligned_dim,
reinterpret_cast<data_t *>(_quantized_data) + loc2 * _aligned_dim,
(uint32_t)this->_aligned_dim);
throw diskann::ANNException("get_distance(const location_t loc1, const location_t loc2) hasn't been implemented for PQDataStore", -1);
}

template <typename data_t>
Expand Down

0 comments on commit d053e33

Please sign in to comment.