diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc index 315f9e9a29a..f6da237569f 100644 --- a/src/search/hnsw_indexer.cc +++ b/src/search/hnsw_indexer.cc @@ -34,9 +34,9 @@ namespace redis { -Node::Node(NodeKey key, uint16_t level) : key(std::move(key)), level(level) {} +HnswNode::HnswNode(NodeKey key, uint16_t level) : key(std::move(key)), level(level) {} -StatusOr Node::DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const { +StatusOr HnswNode::DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const { auto node_index_key = search_key.ConstructHnswNode(level, key); rocksdb::PinnableSlice value; auto s = storage->Get(rocksdb::ReadOptions(), storage->GetCFHandle(ColumnFamilyID::Search), node_index_key, &value); @@ -48,14 +48,14 @@ StatusOr Node::DecodeMetadata(const SearchKey& search_key return metadata; } -void Node::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage, - ObserverOrUniquePtr& batch) const { +void HnswNode::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage, + rocksdb::WriteBatchBase* batch) const { std::string updated_metadata; node_meta->Encode(&updated_metadata); batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), search_key.ConstructHnswNode(level, key), updated_metadata); } -void Node::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) { +void HnswNode::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) { neighbours.clear(); auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key); util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); @@ -71,8 +71,8 @@ void Node::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storag } } -Status Node::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, - ObserverOrUniquePtr& batch) const { +Status HnswNode::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, + rocksdb::WriteBatchBase* batch) const { auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key); batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key, Slice()); @@ -82,8 +82,8 @@ Status Node::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_ return Status::OK(); } -Status Node::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, - ObserverOrUniquePtr& batch) const { +Status HnswNode::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, + rocksdb::WriteBatchBase* batch) const { auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key); auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key); if (!s.ok()) { @@ -187,8 +187,12 @@ StatusOr> HnswIndex::DecodeNodesToVectorItems(const std: vector_items.reserve(node_keys.size()); for (const auto& neighbour_key : node_keys) { - Node neighbour_node(neighbour_key, level); - auto neighbour_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage)); + HnswNode neighbour_node(neighbour_key, level); + auto neighbour_metadata_status = neighbour_node.DecodeMetadata(search_key, storage); + if (!neighbour_metadata_status.IsOK()) { + continue; // Skip this neighbour if metadata can't be decoded + } + auto neighbour_metadata = neighbour_metadata_status.GetValue(); vector_items.emplace_back(VectorItem(neighbour_key, std::move(neighbour_metadata.vector), metadata)); } return vector_items; @@ -257,7 +261,7 @@ StatusOr> HnswIndex::SearchLayer(uint16_t level, const V std::priority_queue> result_heap; for (const auto& entry_point_key : entry_points) { - Node entry_node = Node(entry_point_key, level); + HnswNode entry_node = HnswNode(entry_point_key, level); auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key, storage)); auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata); @@ -269,13 +273,13 @@ StatusOr> HnswIndex::SearchLayer(uint16_t level, const V } while (!explore_heap.empty()) { - auto [dist, current_vector] = explore_heap.top(); + auto [dist, current_vector] = std::move(explore_heap.top()); explore_heap.pop(); if (dist > result_heap.top().first) { break; } - auto current_node = Node(current_vector.key, level); + auto current_node = HnswNode(current_vector.key, level); current_node.DecodeNeighbours(search_key, storage); for (const auto& neighbour_key : current_node.neighbours) { @@ -284,7 +288,7 @@ StatusOr> HnswIndex::SearchLayer(uint16_t level, const V } visited.insert(neighbour_key); - auto neighbour_node = Node(neighbour_key, level); + auto neighbour_node = HnswNode(neighbour_key, level); auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage)); auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata); @@ -327,7 +331,7 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu for (; level >= 0; level--) { nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata->ef_construction, entry_points)); auto candidate_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level)); - auto node = Node(std::string(key), level); + auto node = HnswNode(std::string(key), level); auto m_max = level == 0 ? 2 * metadata->m : metadata->m; std::unordered_set connected_edges_set; @@ -339,7 +343,7 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu }; // Check if candidate node has room after some other nodes' are pruned in current batch - auto has_room_after_deletions = [&](const Node& candidate_node, uint16_t candidate_node_num_neighbours) { + auto has_room_after_deletions = [&](const HnswNode& candidate_node, uint16_t candidate_node_num_neighbours) { auto it = deleted_edges_map.find(candidate_node.key); if (it != deleted_edges_map.end()) { auto num_deleted_edges = static_cast(it->second.size()); @@ -349,7 +353,7 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu }; for (const auto& candidate_vec : candidate_vec_items) { - auto candidate_node = Node(candidate_vec.key, level); + auto candidate_node = HnswNode(candidate_vec.key, level); auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(search_key, storage)); uint16_t candidate_node_num_neighbours = candidate_node_metadata.num_neighbours; @@ -365,27 +369,30 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu auto candidate_node_neighbour_vec_items = GET_OR_RET(DecodeNodesToVectorItems(candidate_node.neighbours, level, search_key, storage, metadata)); candidate_node_neighbour_vec_items.push_back(inserted_vector_item); - auto sorted_neighbours = GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level)); + auto sorted_neighbours_by_distance = + GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level)); + + bool inserted_node_is_selected = + std::find(sorted_neighbours_by_distance.begin(), sorted_neighbours_by_distance.end(), + inserted_vector_item) != sorted_neighbours_by_distance.end(); - bool inserted_node_is_selected = std::find(sorted_neighbours.begin(), sorted_neighbours.end(), - inserted_vector_item) != sorted_neighbours.end(); if (inserted_node_is_selected) { // Add the edge between candidate and inserted node GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch)); connected_edges_set.insert(candidate_node.key); - auto find_deleted_item = [&](const std::vector& candidate_node_neighbour_vec_items, - const std::vector& sorted_neighbours) -> VectorItem { - auto it = std::find_if(candidate_node_neighbour_vec_items.begin(), candidate_node_neighbour_vec_items.end(), - [&](const VectorItem& item) { - return std::find(sorted_neighbours.begin(), sorted_neighbours.end(), item) == - sorted_neighbours.end(); - }); + auto find_deleted_item = [&](const std::vector& candidate_neighbours, + const std::vector& selected_neighbours) -> VectorItem { + auto it = + std::find_if(candidate_neighbours.begin(), candidate_neighbours.end(), [&](const VectorItem& item) { + return std::find(selected_neighbours.begin(), selected_neighbours.end(), item) == + selected_neighbours.end(); + }); return *it; }; // Remove the edge for candidate and the pruned node - auto deleted_node = find_deleted_item(candidate_node_neighbour_vec_items, sorted_neighbours); + auto deleted_node = find_deleted_item(candidate_node_neighbour_vec_items, sorted_neighbours_by_distance); GET_OR_RET(RemoveEdge(deleted_node.key, candidate_node.key, level, batch)); deleted_edges_map[candidate_node.key].insert(deleted_node.key); deleted_edges_map[deleted_node.key].insert(candidate_node.key); @@ -394,12 +401,12 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu // Update inserted node metadata HnswNodeFieldMetadata node_metadata(static_cast(connected_edges_set.size()), vector); - node.PutMetadata(&node_metadata, search_key, storage, batch); + node.PutMetadata(&node_metadata, search_key, storage, batch.Get()); // Update modified nodes metadata for (const auto& node_edges : deleted_edges_map) { auto& current_node_key = node_edges.first; - auto current_node = Node(current_node_key, level); + auto current_node = HnswNode(current_node_key, level); auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage)); auto new_num_neighbours = current_node_metadata.num_neighbours - node_edges.second.size(); if (connected_edges_set.count(current_node_key) != 0) { @@ -407,14 +414,14 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu connected_edges_set.erase(current_node_key); } current_node_metadata.num_neighbours = new_num_neighbours; - current_node.PutMetadata(¤t_node_metadata, search_key, storage, batch); + current_node.PutMetadata(¤t_node_metadata, search_key, storage, batch.Get()); } for (const auto& current_node_key : connected_edges_set) { - auto current_node = Node(current_node_key, level); + auto current_node = HnswNode(current_node_key, level); HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage)); current_node_metadata.num_neighbours++; - current_node.PutMetadata(¤t_node_metadata, search_key, storage, batch); + current_node.PutMetadata(¤t_node_metadata, search_key, storage, batch.Get()); } entry_points.clear(); @@ -423,16 +430,16 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu } } } else { - auto node = Node(std::string(key), 0); + auto node = HnswNode(std::string(key), 0); HnswNodeFieldMetadata node_metadata(0, vector); - node.PutMetadata(&node_metadata, search_key, storage, batch); + node.PutMetadata(&node_metadata, search_key, storage, batch.Get()); metadata->num_levels = 1; } while (target_level > metadata->num_levels - 1) { - auto node = Node(std::string(key), metadata->num_levels); + auto node = HnswNode(std::string(key), metadata->num_levels); HnswNodeFieldMetadata node_metadata(0, vector); - node.PutMetadata(&node_metadata, search_key, storage, batch); + node.PutMetadata(&node_metadata, search_key, storage, batch.Get()); metadata->num_levels++; } @@ -444,16 +451,16 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu return Status::OK(); } -Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vector, +Status HnswIndex::InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector, ObserverOrUniquePtr& batch) { auto target_level = RandomizeLayer(); - return InsertVectorEntryInternal(key, std::move(vector), batch, target_level); + return InsertVectorEntryInternal(key, vector, batch, target_level); } Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr& batch) const { std::string node_key(key); for (uint16_t level = 0; level < metadata->num_levels; level++) { - auto node = Node(node_key, level); + auto node = HnswNode(node_key, level); auto node_metadata_status = node.DecodeMetadata(search_key, storage); if (!node_metadata_status.IsOK()) { break; @@ -469,10 +476,10 @@ Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr neighbours; - Node(NodeKey key, uint16_t level); + HnswNode(NodeKey key, uint16_t level); StatusOr DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const; void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage, - ObserverOrUniquePtr& batch) const; + rocksdb::WriteBatchBase* batch) const; void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage); // For testing purpose Status AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, - ObserverOrUniquePtr& batch) const; + rocksdb::WriteBatchBase* batch) const; Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, - ObserverOrUniquePtr& batch) const; + rocksdb::WriteBatchBase* batch) const; friend class HnswIndex; }; struct VectorItem { - using NodeKey = Node::NodeKey; + using NodeKey = HnswNode::NodeKey; NodeKey key; kqir::NumericArray vector; @@ -71,7 +71,7 @@ struct VectorItem { StatusOr ComputeSimilarity(const VectorItem& left, const VectorItem& right); struct HnswIndex { - using NodeKey = Node::NodeKey; + using NodeKey = HnswNode::NodeKey; SearchKey search_key; HnswVectorFieldMetadata* metadata; @@ -99,7 +99,7 @@ struct HnswIndex { const std::vector& entry_points) const; Status InsertVectorEntryInternal(std::string_view key, const kqir::NumericArray& vector, ObserverOrUniquePtr& batch, uint16_t layer) const; - Status InsertVectorEntry(std::string_view key, kqir::NumericArray vector, + Status InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector, ObserverOrUniquePtr& batch); Status DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr& batch) const; }; diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index afba656d90e..0a757457aae 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -402,7 +402,7 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata { GetFixed16(input, &m); GetFixed32(input, &ef_construction); GetFixed32(input, &ef_runtime); - GetDouble(input, (double *)(&epsilon)); + GetDouble(input, &epsilon); GetFixed16(input, &num_levels); return rocksdb::Status::OK(); } diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc index dbcf851049e..ba0065fa0ce 100644 --- a/tests/cppunit/hnsw_index_test.cc +++ b/tests/cppunit/hnsw_index_test.cc @@ -141,18 +141,18 @@ TEST_F(HnswIndexTest, DecodeNodesToVectorItems) { std::string node_key2 = "node2"; std::string node_key3 = "node3"; - redis::Node node1(node_key1, layer); - redis::Node node2(node_key2, layer); - redis::Node node3(node_key3, layer); + redis::HnswNode node1(node_key1, layer); + redis::HnswNode node2(node_key2, layer); + redis::HnswNode node3(node_key3, layer); redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); auto batch = storage_->GetWriteBatchBase(); - node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage, batch); - node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage, batch); - node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage, batch); + node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage, batch.Get()); auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); ASSERT_TRUE(s.ok()); @@ -221,11 +221,11 @@ TEST_F(HnswIndexTest, SearchLayer) { std::string node_key4 = "node4"; std::string node_key5 = "node5"; - redis::Node node1(node_key1, layer); - redis::Node node2(node_key2, layer); - redis::Node node3(node_key3, layer); - redis::Node node4(node_key4, layer); - redis::Node node5(node_key5, layer); + redis::HnswNode node1(node_key1, layer); + redis::HnswNode node2(node_key2, layer); + redis::HnswNode node3(node_key3, layer); + redis::HnswNode node4(node_key4, layer); + redis::HnswNode node5(node_key5, layer); redis::HnswNodeFieldMetadata metadata1(0, {1.0, 2.0, 3.0}); redis::HnswNodeFieldMetadata metadata2(0, {4.0, 5.0, 6.0}); @@ -235,31 +235,31 @@ TEST_F(HnswIndexTest, SearchLayer) { // Add Nodes auto batch = storage_->GetWriteBatchBase(); - node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage, batch); - node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage, batch); - node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage, batch); - node4.PutMetadata(&metadata4, hnsw_index->search_key, hnsw_index->storage, batch); - node5.PutMetadata(&metadata5, hnsw_index->search_key, hnsw_index->storage, batch); + node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node4.PutMetadata(&metadata4, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node5.PutMetadata(&metadata5, hnsw_index->search_key, hnsw_index->storage, batch.Get()); auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); ASSERT_TRUE(s.ok()); // Add Neighbours batch = storage_->GetWriteBatchBase(); - auto s1 = node1.AddNeighbour("node2", hnsw_index->search_key, hnsw_index->storage, batch); + auto s1 = node1.AddNeighbour("node2", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s1.IsOK()); - auto s2 = node1.AddNeighbour("node4", hnsw_index->search_key, hnsw_index->storage, batch); + auto s2 = node1.AddNeighbour("node4", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s2.IsOK()); - auto s3 = node2.AddNeighbour("node1", hnsw_index->search_key, hnsw_index->storage, batch); + auto s3 = node2.AddNeighbour("node1", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s3.IsOK()); - auto s4 = node2.AddNeighbour("node3", hnsw_index->search_key, hnsw_index->storage, batch); + auto s4 = node2.AddNeighbour("node3", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s1.IsOK()); - auto s5 = node3.AddNeighbour("node2", hnsw_index->search_key, hnsw_index->storage, batch); + auto s5 = node3.AddNeighbour("node2", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s5.IsOK()); - auto s6 = node3.AddNeighbour("node5", hnsw_index->search_key, hnsw_index->storage, batch); + auto s6 = node3.AddNeighbour("node5", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s6.IsOK()); - auto s7 = node4.AddNeighbour("node1", hnsw_index->search_key, hnsw_index->storage, batch); + auto s7 = node4.AddNeighbour("node1", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s7.IsOK()); - auto s8 = node5.AddNeighbour("node3", hnsw_index->search_key, hnsw_index->storage, batch); + auto s8 = node5.AddNeighbour("node3", hnsw_index->search_key, hnsw_index->storage, batch.Get()); ASSERT_TRUE(s8.IsOK()); s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); ASSERT_TRUE(s.ok()); @@ -334,13 +334,13 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { decoded_metadata.Decode(&value); ASSERT_TRUE(decoded_metadata.num_levels == 2); - redis::Node node1_layer0(key1, 0); + redis::HnswNode node1_layer0(key1, 0); auto s2 = node1_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s2.IsOK()); redis::HnswNodeFieldMetadata node1_layer0_meta = s2.GetValue(); EXPECT_EQ(node1_layer0_meta.num_neighbours, 0); - redis::Node node1_layer1(key1, 1); + redis::HnswNode node1_layer1(key1, 1); auto s3 = node1_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s3.IsOK()); redis::HnswNodeFieldMetadata node1_layer1_meta = s2.GetValue(); @@ -369,23 +369,23 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { EXPECT_EQ(node1_layer1.neighbours.size(), 1); EXPECT_EQ(node1_layer1.neighbours[0], "n2"); - redis::Node node2_layer0(key2, 0); + redis::HnswNode node2_layer0(key2, 0); node2_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); EXPECT_EQ(node2_layer0.neighbours.size(), 1); EXPECT_EQ(node2_layer0.neighbours[0], "n1"); - redis::Node node2_layer1(key2, 1); + redis::HnswNode node2_layer1(key2, 1); node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); EXPECT_EQ(node2_layer1.neighbours.size(), 1); EXPECT_EQ(node2_layer1.neighbours[0], "n1"); - redis::Node node2_layer2(key2, 2); + redis::HnswNode node2_layer2(key2, 2); auto s5 = node2_layer2.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s5.IsOK()); redis::HnswNodeFieldMetadata node2_layer2_meta = s5.GetValue(); EXPECT_EQ(node2_layer2_meta.num_neighbours, 0); - redis::Node node2_layer3(key2, 3); + redis::HnswNode node2_layer3(key2, 3); auto s6 = node2_layer3.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s6.IsOK()); redis::HnswNodeFieldMetadata node2_layer3_meta = s6.GetValue(); @@ -406,7 +406,7 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { decoded_metadata.Decode(&value); ASSERT_TRUE(decoded_metadata.num_levels == 4); - redis::Node node3_layer2(key3, target_level); + redis::HnswNode node3_layer2(key3, target_level); auto s8 = node3_layer2.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s8.IsOK()); redis::HnswNodeFieldMetadata node3_layer2_meta = s8.GetValue(); @@ -415,7 +415,7 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { EXPECT_EQ(node3_layer2.neighbours.size(), 1); EXPECT_EQ(node3_layer2.neighbours[0], "n2"); - redis::Node node3_layer1(key3, 1); + redis::HnswNode node3_layer1(key3, 1); auto s9 = node3_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s9.IsOK()); redis::HnswNodeFieldMetadata node3_layer1_meta = s9.GetValue(); @@ -434,7 +434,7 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); ASSERT_TRUE(s.ok()); - redis::Node node4_layer0(key4, 0); + redis::HnswNode node4_layer0(key4, 0); auto s11 = node4_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s11.IsOK()); redis::HnswNodeFieldMetadata node4_layer0_meta = s11.GetValue(); @@ -492,7 +492,7 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { actual_set = {node3_layer1.neighbours.begin(), node3_layer1.neighbours.end()}; EXPECT_EQ(actual_set, expected_set); - redis::Node node4_layer1(key4, 1); + redis::HnswNode node4_layer1(key4, 1); auto s18 = node4_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s18.IsOK()); auto node4_layer1_meta = s18.GetValue(); @@ -502,7 +502,7 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { actual_set = {node4_layer1.neighbours.begin(), node4_layer1.neighbours.end()}; EXPECT_EQ(actual_set, expected_set); - redis::Node node5_layer1(key5, 1); + redis::HnswNode node5_layer1(key5, 1); auto s19 = node5_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s19.IsOK()); auto node5_layer1_meta = s19.GetValue(); @@ -521,7 +521,7 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { actual_set = {node1_layer0.neighbours.begin(), node1_layer0.neighbours.end()}; EXPECT_EQ(actual_set, expected_set); - redis::Node node5_layer0(key5, 0); + redis::HnswNode node5_layer0(key5, 0); auto s21 = node5_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s21.IsOK()); auto node5_layer0_meta = s21.GetValue(); @@ -607,7 +607,7 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { actual_set = {node1_layer0.neighbours.begin(), node1_layer0.neighbours.end()}; EXPECT_EQ(actual_set, expected_set); - redis::Node node3_layer0(key3, 0); + redis::HnswNode node3_layer0(key3, 0); auto s33 = node3_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); ASSERT_TRUE(s33.IsOK()); auto node3_layer0_meta = s33.GetValue(); diff --git a/tests/cppunit/hnsw_node_test.cc b/tests/cppunit/hnsw_node_test.cc index 71a96f7b29e..5fadf9927a5 100644 --- a/tests/cppunit/hnsw_node_test.cc +++ b/tests/cppunit/hnsw_node_test.cc @@ -32,54 +32,54 @@ #include "storage/storage.h" struct NodeTest : public TestBase { - std::string ns = "node_test_ns"; - std::string idx_name = "node_test_idx"; + std::string ns = "hnsw_node_test_ns"; + std::string idx_name = "hnsw_node_test_idx"; std::string key = "vector"; - redis::SearchKey search_key_; + redis::SearchKey search_key; - NodeTest() : search_key_(ns, idx_name, key) {} + NodeTest() : search_key(ns, idx_name, key) {} void TearDown() override {} }; TEST_F(NodeTest, PutAndDecodeMetadata) { uint16_t layer = 0; - redis::Node node1("node1", layer); - redis::Node node2("node2", layer); - redis::Node node3("node3", layer); + redis::HnswNode node1("node1", layer); + redis::HnswNode node2("node2", layer); + redis::HnswNode node3("node3", layer); redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); auto batch = storage_->GetWriteBatchBase(); - node1.PutMetadata(&metadata1, search_key_, storage_.get(), batch); - node2.PutMetadata(&metadata2, search_key_, storage_.get(), batch); - node3.PutMetadata(&metadata3, search_key_, storage_.get(), batch); + node1.PutMetadata(&metadata1, search_key, storage_.get(), batch.Get()); + node2.PutMetadata(&metadata2, search_key, storage_.get(), batch.Get()); + node3.PutMetadata(&metadata3, search_key, storage_.get(), batch.Get()); auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); ASSERT_TRUE(s.ok()); - auto decoded_metadata1 = node1.DecodeMetadata(search_key_, storage_.get()); + auto decoded_metadata1 = node1.DecodeMetadata(search_key, storage_.get()); ASSERT_TRUE(decoded_metadata1.IsOK()); ASSERT_EQ(decoded_metadata1.GetValue().num_neighbours, 0); ASSERT_EQ(decoded_metadata1.GetValue().vector, std::vector({1, 2, 3})); - auto decoded_metadata2 = node2.DecodeMetadata(search_key_, storage_.get()); + auto decoded_metadata2 = node2.DecodeMetadata(search_key, storage_.get()); ASSERT_TRUE(decoded_metadata2.IsOK()); ASSERT_EQ(decoded_metadata2.GetValue().num_neighbours, 0); ASSERT_EQ(decoded_metadata2.GetValue().vector, std::vector({4, 5, 6})); - auto decoded_metadata3 = node3.DecodeMetadata(search_key_, storage_.get()); + auto decoded_metadata3 = node3.DecodeMetadata(search_key, storage_.get()); ASSERT_TRUE(decoded_metadata3.IsOK()); ASSERT_EQ(decoded_metadata3.GetValue().num_neighbours, 0); ASSERT_EQ(decoded_metadata3.GetValue().vector, std::vector({7, 8, 9})); // Prepare edges between node1 and node2 batch = storage_->GetWriteBatchBase(); - auto edge1 = search_key_.ConstructHnswEdge(layer, "node1", "node2"); - auto edge2 = search_key_.ConstructHnswEdge(layer, "node2", "node1"); - auto edge3 = search_key_.ConstructHnswEdge(layer, "node2", "node3"); - auto edge4 = search_key_.ConstructHnswEdge(layer, "node3", "node2"); + auto edge1 = search_key.ConstructHnswEdge(layer, "node1", "node2"); + auto edge2 = search_key.ConstructHnswEdge(layer, "node2", "node1"); + auto edge3 = search_key.ConstructHnswEdge(layer, "node2", "node3"); + auto edge4 = search_key.ConstructHnswEdge(layer, "node3", "node2"); batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge1, Slice()); batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge2, Slice()); @@ -88,27 +88,27 @@ TEST_F(NodeTest, PutAndDecodeMetadata) { s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); ASSERT_TRUE(s.ok()); - node1.DecodeNeighbours(search_key_, storage_.get()); + node1.DecodeNeighbours(search_key, storage_.get()); EXPECT_EQ(node1.neighbours.size(), 1); EXPECT_EQ(node1.neighbours[0], "node2"); - node2.DecodeNeighbours(search_key_, storage_.get()); + node2.DecodeNeighbours(search_key, storage_.get()); EXPECT_EQ(node2.neighbours.size(), 2); std::unordered_set expected_neighbours = {"node1", "node3"}; std::unordered_set actual_neighbours(node2.neighbours.begin(), node2.neighbours.end()); EXPECT_EQ(actual_neighbours, expected_neighbours); - node3.DecodeNeighbours(search_key_, storage_.get()); + node3.DecodeNeighbours(search_key, storage_.get()); EXPECT_EQ(node3.neighbours.size(), 1); EXPECT_EQ(node3.neighbours[0], "node2"); } TEST_F(NodeTest, ModifyNeighbours) { uint16_t layer = 1; - redis::Node node1("node1", layer); - redis::Node node2("node2", layer); - redis::Node node3("node3", layer); - redis::Node node4("node4", layer); + redis::HnswNode node1("node1", layer); + redis::HnswNode node2("node2", layer); + redis::HnswNode node3("node3", layer); + redis::HnswNode node4("node4", layer); redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); @@ -117,49 +117,49 @@ TEST_F(NodeTest, ModifyNeighbours) { // Add Nodes auto batch1 = storage_->GetWriteBatchBase(); - node1.PutMetadata(&metadata1, search_key_, storage_.get(), batch1); - node2.PutMetadata(&metadata2, search_key_, storage_.get(), batch1); - node3.PutMetadata(&metadata3, search_key_, storage_.get(), batch1); - node4.PutMetadata(&metadata4, search_key_, storage_.get(), batch1); + node1.PutMetadata(&metadata1, search_key, storage_.get(), batch1.Get()); + node2.PutMetadata(&metadata2, search_key, storage_.get(), batch1.Get()); + node3.PutMetadata(&metadata3, search_key, storage_.get(), batch1.Get()); + node4.PutMetadata(&metadata4, search_key, storage_.get(), batch1.Get()); auto s = storage_->Write(storage_->DefaultWriteOptions(), batch1->GetWriteBatch()); ASSERT_TRUE(s.ok()); // Add Edges auto batch2 = storage_->GetWriteBatchBase(); - auto s1 = node1.AddNeighbour("node2", search_key_, storage_.get(), batch2); + auto s1 = node1.AddNeighbour("node2", search_key, storage_.get(), batch2.Get()); ASSERT_TRUE(s1.IsOK()); - auto s2 = node2.AddNeighbour("node1", search_key_, storage_.get(), batch2); + auto s2 = node2.AddNeighbour("node1", search_key, storage_.get(), batch2.Get()); ASSERT_TRUE(s2.IsOK()); - auto s3 = node2.AddNeighbour("node3", search_key_, storage_.get(), batch2); + auto s3 = node2.AddNeighbour("node3", search_key, storage_.get(), batch2.Get()); ASSERT_TRUE(s3.IsOK()); - auto s4 = node3.AddNeighbour("node2", search_key_, storage_.get(), batch2); + auto s4 = node3.AddNeighbour("node2", search_key, storage_.get(), batch2.Get()); ASSERT_TRUE(s4.IsOK()); s = storage_->Write(storage_->DefaultWriteOptions(), batch2->GetWriteBatch()); ASSERT_TRUE(s.ok()); - node1.DecodeNeighbours(search_key_, storage_.get()); + node1.DecodeNeighbours(search_key, storage_.get()); EXPECT_EQ(node1.neighbours.size(), 1); EXPECT_EQ(node1.neighbours[0], "node2"); - node2.DecodeNeighbours(search_key_, storage_.get()); + node2.DecodeNeighbours(search_key, storage_.get()); EXPECT_EQ(node2.neighbours.size(), 2); std::unordered_set expected_neighbours = {"node1", "node3"}; std::unordered_set actual_neighbours(node2.neighbours.begin(), node2.neighbours.end()); EXPECT_EQ(actual_neighbours, expected_neighbours); - node3.DecodeNeighbours(search_key_, storage_.get()); + node3.DecodeNeighbours(search_key, storage_.get()); EXPECT_EQ(node3.neighbours.size(), 1); EXPECT_EQ(node3.neighbours[0], "node2"); // Remove Edges auto batch3 = storage_->GetWriteBatchBase(); - auto s5 = node2.RemoveNeighbour("node3", search_key_, storage_.get(), batch3); + auto s5 = node2.RemoveNeighbour("node3", search_key, storage_.get(), batch3.Get()); ASSERT_TRUE(s5.IsOK()); s = storage_->Write(storage_->DefaultWriteOptions(), batch3->GetWriteBatch()); ASSERT_TRUE(s.ok()); - node2.DecodeNeighbours(search_key_, storage_.get()); + node2.DecodeNeighbours(search_key, storage_.get()); EXPECT_EQ(node2.neighbours.size(), 1); EXPECT_EQ(node2.neighbours[0], "node1"); }