Skip to content

Commit

Permalink
Fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
Beihao-Zhou committed Jul 12, 2024
1 parent 22fc5d7 commit 224141f
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 27 deletions.
45 changes: 37 additions & 8 deletions src/search/hnsw_indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,36 @@ Status HnswNode::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey&
return Status::OK();
}

// Constructs a VectorItem from a key taken by value and a vector taken by const reference.
// Inside the initializer list the parameter names shadow the members of the same name, so
// `key(std::move(key))` moves the *parameter* into the member, `vector(vector)` copies the
// caller's array, and the metadata pointer is stored as-is.
// No check is made here that the vector length agrees with the metadata.
VectorItem::VectorItem(NodeKey key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata)
: key(std::move(key)), vector(vector), metadata(metadata) {}
// Constructs a VectorItem from a key taken by value and a vector taken by rvalue reference.
// Inside the initializer list the parameter names shadow the members of the same name, so both
// `key` and `vector` are moved from the *parameters* into the members; the metadata pointer is
// stored as-is.
// No check is made here that the vector length agrees with the metadata.
VectorItem::VectorItem(NodeKey key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata)
: key(std::move(key)), vector(std::move(vector)), metadata(metadata) {}
// Validating factory for VectorItem (copying overload).
//
// key:      node key; taken by value and moved into the new item.
// vector:   vector payload; copied into the new item.
// metadata: field metadata whose `dim` must equal `vector.size()`; must not be null.
// out:      receives the constructed item on success; must not be null.
//
// Returns Status::OK() on success. Returns Status::InvalidArgument when `metadata` or `out`
// is null, or when the vector length differs from `metadata->dim`; `*out` is left untouched
// in every failure case.
Status VectorItem::Create(NodeKey key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata,
                          VectorItem* out) {
  // Both pointers are dereferenced below; reject null up front rather than invoking
  // undefined behaviour (a default-constructed VectorItem carries a null metadata pointer).
  if (metadata == nullptr || out == nullptr) {
    return {Status::InvalidArgument, "VectorItem's metadata and output pointer must not be null."};
  }

  if (metadata->dim != vector.size()) {
    return {Status::InvalidArgument, "VectorItem's metadata dimension must be consistent with the vector itself."};
  }

  *out = VectorItem(std::move(key), vector, metadata);
  return Status::OK();
}

// Validating factory for VectorItem (moving overload).
//
// key:      node key; taken by value and moved into the new item.
// vector:   vector payload; moved into the new item only after validation succeeds.
// metadata: field metadata whose `dim` must equal `vector.size()`; must not be null.
// out:      receives the constructed item on success; must not be null.
//
// Returns Status::OK() on success. Returns Status::InvalidArgument when `metadata` or `out`
// is null, or when the vector length differs from `metadata->dim`; in every failure case
// `*out` is left untouched and `vector` has not been moved from.
Status VectorItem::Create(NodeKey key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata,
                          VectorItem* out) {
  // Both pointers are dereferenced below; reject null up front rather than invoking
  // undefined behaviour (a default-constructed VectorItem carries a null metadata pointer).
  if (metadata == nullptr || out == nullptr) {
    return {Status::InvalidArgument, "VectorItem's metadata and output pointer must not be null."};
  }

  if (metadata->dim != vector.size()) {
    return {Status::InvalidArgument, "VectorItem's metadata dimension must be consistent with the vector itself."};
  }

  *out = VectorItem(std::move(key), std::move(vector), metadata);
  return Status::OK();
}

// Equality is decided by the node key alone; the vector payload and the metadata pointer
// do not take part in the comparison.
bool VectorItem::operator==(const VectorItem& other) const {
  return this->key == other.key;
}

// Ordering is decided by the node key alone; the vector payload and the metadata pointer
// do not take part in the comparison.
bool VectorItem::operator<(const VectorItem& other) const {
  return this->key < other.key;
}

// Constructs a VectorItem from an rvalue key and a vector taken by const reference.
// Inside the initializer list the parameter names shadow the members of the same name, so
// `key(std::move(key))` moves the *parameter* into the member, `vector(vector)` copies the
// caller's array, and the metadata pointer is stored as-is.
// This constructor performs no dimension validation of its own.
VectorItem::VectorItem(NodeKey&& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata)
: key(std::move(key)), vector(vector), metadata(metadata) {}

// Constructs a VectorItem from an rvalue key and an rvalue vector.
// Inside the initializer list the parameter names shadow the members of the same name, so both
// `key` and `vector` are moved from the *parameters* into the members; the metadata pointer is
// stored as-is.
// This constructor performs no dimension validation of its own.
VectorItem::VectorItem(NodeKey&& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata)
: key(std::move(key)), vector(std::move(vector)), metadata(metadata) {}

StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right) {
if (left.metadata->distance_metric != right.metadata->distance_metric || left.metadata->dim != right.metadata->dim)
return {Status::InvalidArgument, "Vectors must be of the same metric and dimension to compute distance."};
Expand Down Expand Up @@ -194,7 +215,9 @@ StatusOr<std::vector<VectorItem>> HnswIndex::DecodeNodesToVectorItems(const std:
continue; // Skip this neighbour if metadata can't be decoded
}
auto neighbour_metadata = neighbour_metadata_status.GetValue();
vector_items.emplace_back(VectorItem(neighbour_key, std::move(neighbour_metadata.vector), metadata));
VectorItem item;
GET_OR_RET(VectorItem::Create(neighbour_key, std::move(neighbour_metadata.vector), metadata, &item));
vector_items.emplace_back(std::move(item));
}
return vector_items;
}
Expand Down Expand Up @@ -265,7 +288,9 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
HnswNode entry_node = HnswNode(entry_point_key, level);
auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key, storage));

auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata);
VectorItem entry_point_vector;
GET_OR_RET(
VectorItem::Create(entry_point_key, std::move(entry_node_metadata.vector), metadata, &entry_point_vector));
auto dist = GET_OR_RET(ComputeSimilarity(target_vector, entry_point_vector));

explore_heap.push(std::make_pair(dist, entry_point_vector));
Expand All @@ -291,7 +316,10 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V

auto neighbour_node = HnswNode(neighbour_key, level);
auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata);

VectorItem neighbour_node_vector;
GET_OR_RET(VectorItem::Create(neighbour_key, std::move(neighbour_node_metadata.vector), metadata,
&neighbour_node_vector));

auto dist = GET_OR_RET(ComputeSimilarity(target_vector, neighbour_node_vector));
explore_heap.push(std::make_pair(dist, neighbour_node_vector));
Expand All @@ -315,7 +343,8 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
uint16_t target_level) const {
auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
auto inserted_vector_item = VectorItem(std::string(key), vector, metadata);
VectorItem inserted_vector_item;
GET_OR_RET(VectorItem::Create(std::string(key), vector, metadata, &inserted_vector_item));
std::vector<VectorItem> nearest_vec_items;

if (metadata->num_levels != 0) {
Expand Down
12 changes: 10 additions & 2 deletions src/search/hnsw_indexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,19 @@ struct VectorItem {
kqir::NumericArray vector;
const HnswVectorFieldMetadata* metadata;

VectorItem(NodeKey key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata);
VectorItem(NodeKey key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata);
VectorItem() : metadata(nullptr) {}

static Status Create(NodeKey key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata,
VectorItem* out);
static Status Create(NodeKey key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata,
VectorItem* out);

bool operator==(const VectorItem& other) const;
bool operator<(const VectorItem& other) const;

private:
VectorItem(NodeKey&& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata);
VectorItem(NodeKey&& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata);
};

StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right);
Expand Down
61 changes: 44 additions & 17 deletions tests/cppunit/hnsw_index_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct HnswIndexTest : TestBase {

HnswIndexTest() {
metadata.vector_type = redis::VectorType::FLOAT64;
metadata.dim = 4;
metadata.dim = 3;
metadata.m = 3;
metadata.distance_metric = redis::DistanceMetric::L2;
auto search_key = redis::SearchKey(ns, idx_name, key);
Expand All @@ -51,9 +51,15 @@ struct HnswIndexTest : TestBase {
};

TEST_F(HnswIndexTest, ComputeSimilarity) {
redis::VectorItem vec1 = {"1", {1.0, 1.2, 1.4, 1.6}, hnsw_index->metadata};
redis::VectorItem vec2 = {"2", {3.0, 3.2, 3.4, 3.6}, hnsw_index->metadata};
redis::VectorItem vec3 = {"3", {1.0, 1.2, 1.4, 1.6}, hnsw_index->metadata}; // identical to vec1
redis::VectorItem vec1;
auto status1 = redis::VectorItem::Create("1", {1.0, 1.2, 1.4}, hnsw_index->metadata, &vec1);
ASSERT_TRUE(status1.IsOK());
redis::VectorItem vec2;
auto status2 = redis::VectorItem::Create("2", {3.0, 3.2, 3.4}, hnsw_index->metadata, &vec2);
ASSERT_TRUE(status2.IsOK());
redis::VectorItem vec3; // identical to vec1
auto status3 = redis::VectorItem::Create("3", {1.0, 1.2, 1.4}, hnsw_index->metadata, &vec3);
ASSERT_TRUE(status3.IsOK());

auto s1 = redis::ComputeSimilarity(vec1, vec3);
ASSERT_TRUE(s1.IsOK());
Expand All @@ -63,18 +69,17 @@ TEST_F(HnswIndexTest, ComputeSimilarity) {
auto s2 = redis::ComputeSimilarity(vec1, vec2);
ASSERT_TRUE(s2.IsOK());
similarity = s2.GetValue();
EXPECT_EQ(similarity, 4.0);
EXPECT_NEAR(similarity, std::sqrt(12), 1e-5);

hnsw_index->metadata->distance_metric = redis::DistanceMetric::IP;
auto s3 = redis::ComputeSimilarity(vec1, vec2);
ASSERT_TRUE(s3.IsOK());
similarity = s3.GetValue();
EXPECT_NEAR(similarity, -17.36, 1e-5);
EXPECT_NEAR(similarity, -(1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 3.4), 1e-5);

hnsw_index->metadata->distance_metric = redis::DistanceMetric::COSINE;
double expected_res =
(1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 3.4 + 1.6 * 3.6) /
std::sqrt((1.0 * 1.0 + 1.2 * 1.2 + 1.4 * 1.4 + 1.6 * 1.6) * (3.0 * 3.0 + 3.2 * 3.2 + 3.4 * 3.4 + 3.6 * 3.6));
double expected_res = (1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 3.4) /
std::sqrt((1.0 * 1.0 + 1.2 * 1.2 + 1.4 * 1.4) * (3.0 * 3.0 + 3.2 * 3.2 + 3.4 * 3.4));
auto s4 = redis::ComputeSimilarity(vec1, vec2);
ASSERT_TRUE(s4.IsOK());
similarity = s4.GetValue();
Expand Down Expand Up @@ -172,13 +177,33 @@ TEST_F(HnswIndexTest, DecodeNodesToVectorItems) {
}

TEST_F(HnswIndexTest, SelectNeighbors) {
redis::VectorItem vec1 = {"1", {1.0, 1.0, 1.0, 1.0}, hnsw_index->metadata};
redis::VectorItem vec2 = {"2", {2.0, 2.0, 2.0, 2.0}, hnsw_index->metadata};
redis::VectorItem vec3 = {"3", {3.0, 3.0, 3.0, 3.0}, hnsw_index->metadata};
redis::VectorItem vec4 = {"4", {4.0, 4.0, 4.0, 4.0}, hnsw_index->metadata};
redis::VectorItem vec5 = {"5", {5.0, 5.0, 5.0, 5.0}, hnsw_index->metadata};
redis::VectorItem vec6 = {"6", {6.0, 6.0, 6.0, 6.0}, hnsw_index->metadata};
redis::VectorItem vec7 = {"7", {7.0, 7.0, 7.0, 7.0}, hnsw_index->metadata};
redis::VectorItem vec1;
auto status1 = redis::VectorItem::Create("1", {1.0, 1.0, 1.0}, hnsw_index->metadata, &vec1);
ASSERT_TRUE(status1.IsOK());

redis::VectorItem vec2;
auto status2 = redis::VectorItem::Create("2", {2.0, 2.0, 2.0}, hnsw_index->metadata, &vec2);
ASSERT_TRUE(status2.IsOK());

redis::VectorItem vec3;
auto status3 = redis::VectorItem::Create("3", {3.0, 3.0, 3.0}, hnsw_index->metadata, &vec3);
ASSERT_TRUE(status3.IsOK());

redis::VectorItem vec4;
auto status4 = redis::VectorItem::Create("4", {4.0, 4.0, 4.0}, hnsw_index->metadata, &vec4);
ASSERT_TRUE(status4.IsOK());

redis::VectorItem vec5;
auto status5 = redis::VectorItem::Create("5", {5.0, 5.0, 5.0}, hnsw_index->metadata, &vec5);
ASSERT_TRUE(status5.IsOK());

redis::VectorItem vec6;
auto status6 = redis::VectorItem::Create("6", {6.0, 6.0, 6.0}, hnsw_index->metadata, &vec6);
ASSERT_TRUE(status6.IsOK());

redis::VectorItem vec7;
auto status7 = redis::VectorItem::Create("7", {7.0, 7.0, 7.0}, hnsw_index->metadata, &vec7);
ASSERT_TRUE(status7.IsOK());

std::vector<redis::VectorItem> candidates = {vec3, vec2};
auto s1 = hnsw_index->SelectNeighbors(vec1, candidates, 1);
Expand Down Expand Up @@ -264,7 +289,9 @@ TEST_F(HnswIndexTest, SearchLayer) {
s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());

redis::VectorItem target_vector("target", {2.0, 3.0, 4.0}, hnsw_index->metadata);
redis::VectorItem target_vector;
auto status = redis::VectorItem::Create("target", {2.0, 3.0, 4.0}, hnsw_index->metadata, &target_vector);
ASSERT_TRUE(status.IsOK());

// Test with multiple entry points
std::vector<std::string> entry_points = {"node3", "node2"};
Expand Down

0 comments on commit 224141f

Please sign in to comment.