Skip to content

Commit

Permalink
Modify for review
Browse files Browse the repository at this point in the history
  • Loading branch information
Beihao-Zhou committed Jul 9, 2024
1 parent 76e311f commit f7900fd
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 126 deletions.
93 changes: 50 additions & 43 deletions src/search/hnsw_indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@

namespace redis {

Node::Node(NodeKey key, uint16_t level) : key(std::move(key)), level(level) {}
HnswNode::HnswNode(NodeKey key, uint16_t level) : key(std::move(key)), level(level) {}

StatusOr<HnswNodeFieldMetadata> Node::DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const {
StatusOr<HnswNodeFieldMetadata> HnswNode::DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const {
auto node_index_key = search_key.ConstructHnswNode(level, key);
rocksdb::PinnableSlice value;
auto s = storage->Get(rocksdb::ReadOptions(), storage->GetCFHandle(ColumnFamilyID::Search), node_index_key, &value);
Expand All @@ -48,14 +48,14 @@ StatusOr<HnswNodeFieldMetadata> Node::DecodeMetadata(const SearchKey& search_key
return metadata;
}

void Node::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
void HnswNode::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
rocksdb::WriteBatchBase* batch) const {
std::string updated_metadata;
node_meta->Encode(&updated_metadata);
batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), search_key.ConstructHnswNode(level, key), updated_metadata);
}

void Node::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) {
void HnswNode::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) {
neighbours.clear();
auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key);
util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
Expand All @@ -71,8 +71,8 @@ void Node::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storag
}
}

Status Node::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
Status HnswNode::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
rocksdb::WriteBatchBase* batch) const {
auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key);
batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key, Slice());

Expand All @@ -82,8 +82,8 @@ Status Node::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_
return Status::OK();
}

Status Node::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
Status HnswNode::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
rocksdb::WriteBatchBase* batch) const {
auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key);
auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key);
if (!s.ok()) {
Expand Down Expand Up @@ -187,8 +187,12 @@ StatusOr<std::vector<VectorItem>> HnswIndex::DecodeNodesToVectorItems(const std:
vector_items.reserve(node_keys.size());

for (const auto& neighbour_key : node_keys) {
Node neighbour_node(neighbour_key, level);
auto neighbour_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
HnswNode neighbour_node(neighbour_key, level);
auto neighbour_metadata_status = neighbour_node.DecodeMetadata(search_key, storage);
if (!neighbour_metadata_status.IsOK()) {
continue; // Skip this neighbour if metadata can't be decoded
}
auto neighbour_metadata = neighbour_metadata_status.GetValue();
vector_items.emplace_back(VectorItem(neighbour_key, std::move(neighbour_metadata.vector), metadata));
}
return vector_items;
Expand Down Expand Up @@ -257,7 +261,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
std::priority_queue<std::pair<double, VectorItem>> result_heap;

for (const auto& entry_point_key : entry_points) {
Node entry_node = Node(entry_point_key, level);
HnswNode entry_node = HnswNode(entry_point_key, level);
auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key, storage));

auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata);
Expand All @@ -269,13 +273,13 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
}

while (!explore_heap.empty()) {
auto [dist, current_vector] = explore_heap.top();
auto [dist, current_vector] = std::move(explore_heap.top());
explore_heap.pop();
if (dist > result_heap.top().first) {
break;
}

auto current_node = Node(current_vector.key, level);
auto current_node = HnswNode(current_vector.key, level);
current_node.DecodeNeighbours(search_key, storage);

for (const auto& neighbour_key : current_node.neighbours) {
Expand All @@ -284,7 +288,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
}
visited.insert(neighbour_key);

auto neighbour_node = Node(neighbour_key, level);
auto neighbour_node = HnswNode(neighbour_key, level);
auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata);

Expand Down Expand Up @@ -327,7 +331,7 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
for (; level >= 0; level--) {
nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata->ef_construction, entry_points));
auto candidate_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
auto node = Node(std::string(key), level);
auto node = HnswNode(std::string(key), level);
auto m_max = level == 0 ? 2 * metadata->m : metadata->m;

std::unordered_set<NodeKey> connected_edges_set;
Expand All @@ -339,7 +343,7 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
};

// Check if the candidate node has room after some of its edges are pruned in the current batch
auto has_room_after_deletions = [&](const Node& candidate_node, uint16_t candidate_node_num_neighbours) {
auto has_room_after_deletions = [&](const HnswNode& candidate_node, uint16_t candidate_node_num_neighbours) {
auto it = deleted_edges_map.find(candidate_node.key);
if (it != deleted_edges_map.end()) {
auto num_deleted_edges = static_cast<uint16_t>(it->second.size());
Expand All @@ -349,7 +353,7 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
};

for (const auto& candidate_vec : candidate_vec_items) {
auto candidate_node = Node(candidate_vec.key, level);
auto candidate_node = HnswNode(candidate_vec.key, level);
auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(search_key, storage));
uint16_t candidate_node_num_neighbours = candidate_node_metadata.num_neighbours;

Expand All @@ -365,27 +369,30 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
auto candidate_node_neighbour_vec_items =
GET_OR_RET(DecodeNodesToVectorItems(candidate_node.neighbours, level, search_key, storage, metadata));
candidate_node_neighbour_vec_items.push_back(inserted_vector_item);
auto sorted_neighbours = GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level));
auto sorted_neighbours_by_distance =
GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level));

bool inserted_node_is_selected =
std::find(sorted_neighbours_by_distance.begin(), sorted_neighbours_by_distance.end(),
inserted_vector_item) != sorted_neighbours_by_distance.end();

bool inserted_node_is_selected = std::find(sorted_neighbours.begin(), sorted_neighbours.end(),
inserted_vector_item) != sorted_neighbours.end();
if (inserted_node_is_selected) {
// Add the edge between candidate and inserted node
GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
connected_edges_set.insert(candidate_node.key);

auto find_deleted_item = [&](const std::vector<VectorItem>& candidate_node_neighbour_vec_items,
const std::vector<VectorItem>& sorted_neighbours) -> VectorItem {
auto it = std::find_if(candidate_node_neighbour_vec_items.begin(), candidate_node_neighbour_vec_items.end(),
[&](const VectorItem& item) {
return std::find(sorted_neighbours.begin(), sorted_neighbours.end(), item) ==
sorted_neighbours.end();
});
auto find_deleted_item = [&](const std::vector<VectorItem>& candidate_neighbours,
const std::vector<VectorItem>& selected_neighbours) -> VectorItem {
auto it =
std::find_if(candidate_neighbours.begin(), candidate_neighbours.end(), [&](const VectorItem& item) {
return std::find(selected_neighbours.begin(), selected_neighbours.end(), item) ==
selected_neighbours.end();
});
return *it;
};

// Remove the edge for candidate and the pruned node
auto deleted_node = find_deleted_item(candidate_node_neighbour_vec_items, sorted_neighbours);
auto deleted_node = find_deleted_item(candidate_node_neighbour_vec_items, sorted_neighbours_by_distance);
GET_OR_RET(RemoveEdge(deleted_node.key, candidate_node.key, level, batch));
deleted_edges_map[candidate_node.key].insert(deleted_node.key);
deleted_edges_map[deleted_node.key].insert(candidate_node.key);
Expand All @@ -394,27 +401,27 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu

// Update inserted node metadata
HnswNodeFieldMetadata node_metadata(static_cast<uint16_t>(connected_edges_set.size()), vector);
node.PutMetadata(&node_metadata, search_key, storage, batch);
node.PutMetadata(&node_metadata, search_key, storage, batch.Get());

// Update modified nodes metadata
for (const auto& node_edges : deleted_edges_map) {
auto& current_node_key = node_edges.first;
auto current_node = Node(current_node_key, level);
auto current_node = HnswNode(current_node_key, level);
auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
auto new_num_neighbours = current_node_metadata.num_neighbours - node_edges.second.size();
if (connected_edges_set.count(current_node_key) != 0) {
new_num_neighbours++;
connected_edges_set.erase(current_node_key);
}
current_node_metadata.num_neighbours = new_num_neighbours;
current_node.PutMetadata(&current_node_metadata, search_key, storage, batch);
current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
}

for (const auto& current_node_key : connected_edges_set) {
auto current_node = Node(current_node_key, level);
auto current_node = HnswNode(current_node_key, level);
HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
current_node_metadata.num_neighbours++;
current_node.PutMetadata(&current_node_metadata, search_key, storage, batch);
current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
}

entry_points.clear();
Expand All @@ -423,16 +430,16 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
}
}
} else {
auto node = Node(std::string(key), 0);
auto node = HnswNode(std::string(key), 0);
HnswNodeFieldMetadata node_metadata(0, vector);
node.PutMetadata(&node_metadata, search_key, storage, batch);
node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
metadata->num_levels = 1;
}

while (target_level > metadata->num_levels - 1) {
auto node = Node(std::string(key), metadata->num_levels);
auto node = HnswNode(std::string(key), metadata->num_levels);
HnswNodeFieldMetadata node_metadata(0, vector);
node.PutMetadata(&node_metadata, search_key, storage, batch);
node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
metadata->num_levels++;
}

Expand All @@ -444,16 +451,16 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
return Status::OK();
}

Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
Status HnswIndex::InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto target_level = RandomizeLayer();
return InsertVectorEntryInternal(key, std::move(vector), batch, target_level);
return InsertVectorEntryInternal(key, vector, batch, target_level);
}

Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
std::string node_key(key);
for (uint16_t level = 0; level < metadata->num_levels; level++) {
auto node = Node(node_key, level);
auto node = HnswNode(node_key, level);
auto node_metadata_status = node.DecodeMetadata(search_key, storage);
if (!node_metadata_status.IsOK()) {
break;
Expand All @@ -469,10 +476,10 @@ Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<ro
node.DecodeNeighbours(search_key, storage);
for (const auto& neighbour_key : node.neighbours) {
GET_OR_RET(RemoveEdge(node_key, neighbour_key, level, batch));
auto neighbour_node = Node(neighbour_key, level);
auto neighbour_node = HnswNode(neighbour_key, level);
HnswNodeFieldMetadata neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
neighbour_node_metadata.num_neighbours--;
neighbour_node.PutMetadata(&neighbour_node_metadata, search_key, storage, batch);
neighbour_node.PutMetadata(&neighbour_node_metadata, search_key, storage, batch.Get());
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/search/hnsw_indexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,29 @@ namespace redis {

class HnswIndex;

struct Node {
struct HnswNode {
using NodeKey = std::string;
NodeKey key;
uint16_t level;
std::vector<NodeKey> neighbours;

Node(NodeKey key, uint16_t level);
HnswNode(NodeKey key, uint16_t level);

StatusOr<HnswNodeFieldMetadata> DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const;
void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
rocksdb::WriteBatchBase* batch) const;
void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage);

// For testing purpose
Status AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
rocksdb::WriteBatchBase* batch) const;
Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
rocksdb::WriteBatchBase* batch) const;
friend class HnswIndex;
};

struct VectorItem {
using NodeKey = Node::NodeKey;
using NodeKey = HnswNode::NodeKey;

NodeKey key;
kqir::NumericArray vector;
Expand All @@ -71,7 +71,7 @@ struct VectorItem {
StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right);

struct HnswIndex {
using NodeKey = Node::NodeKey;
using NodeKey = HnswNode::NodeKey;

SearchKey search_key;
HnswVectorFieldMetadata* metadata;
Expand Down Expand Up @@ -99,7 +99,7 @@ struct HnswIndex {
const std::vector<NodeKey>& entry_points) const;
Status InsertVectorEntryInternal(std::string_view key, const kqir::NumericArray& vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, uint16_t layer) const;
Status InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
Status InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
};
Expand Down
2 changes: 1 addition & 1 deletion src/search/search_encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata {
GetFixed16(input, &m);
GetFixed32(input, &ef_construction);
GetFixed32(input, &ef_runtime);
GetDouble(input, (double *)(&epsilon));
GetDouble(input, &epsilon);
GetFixed16(input, &num_levels);
return rocksdb::Status::OK();
}
Expand Down
Loading

0 comments on commit f7900fd

Please sign in to comment.