Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Support HasRawData interface (#947)
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Weizhi Xu <[email protected]>
  • Loading branch information
PwzXxm authored Jun 21, 2023
1 parent 998a8b8 commit b9bdbfc
Show file tree
Hide file tree
Showing 21 changed files with 93 additions and 0 deletions.
4 changes: 4 additions & 0 deletions knowhere/index/VecIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <utility>
#include <vector>
#include <climits>
#include <string>

#include "knowhere/common/Dataset.h"
#include "knowhere/common/Exception.h"
Expand Down Expand Up @@ -84,6 +85,9 @@ class VecIndex : public Index {
KNOWHERE_THROW_MSG("GetVectorById not supported yet");
}

virtual bool
HasRawData(const std::string& metric_type) const = 0;

/**
* @brief TopK Query. if the result size is smaller than K, this API will fill the return ids with -1 and distances
* with FLOAT_MIN or FLOAT_MAX depends on the metric type.
Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/VecIndexThreadPoolWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#pragma once

#include <string>
#include <memory>
#include <utility>

Expand Down Expand Up @@ -63,6 +64,11 @@ class VecIndexThreadPoolWrapper : public VecIndex {
return index_->GetVectorById(dataset, config);
}

bool
HasRawData(const std::string& metric_type) const override {
return index_->HasRawData(metric_type);
}

DatasetPtr
Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView bitset) override {
return thread_pool_->push([&]() { return this->index_->Query(dataset, config, bitset); }).get();
Expand Down
5 changes: 5 additions & 0 deletions knowhere/index/vector_index/IndexAnnoy.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class IndexAnnoy : public VecIndex {
DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

bool
HasRawData(const std::string& /*metric_type*/) const override {
return true;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_index/IndexBinaryIDMAP.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <string>

#include "knowhere/index/VecIndex.h"
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
Expand Down Expand Up @@ -45,6 +46,11 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

bool
HasRawData(const std::string& /*metric_type*/) const override {
return true;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_index/IndexBinaryIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <string>

#include <faiss/IndexIVF.h>

Expand Down Expand Up @@ -50,6 +51,11 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

bool
HasRawData(const std::string& /*metric_type*/) const override {
return true;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
5 changes: 5 additions & 0 deletions knowhere/index/vector_index/IndexDiskANN.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class IndexDiskANN : public VecIndex {
KNOWHERE_THROW_MSG("DiskANN doesn't support GetVectorById.");
}

bool
HasRawData(const std::string& /*metric_type*/) const override {
return false;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_index/IndexHNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <memory>
#include <unordered_set>
#include <string>

#include "hnswlib/hnswlib/hnswlib.h"
#include "knowhere/common/Exception.h"
Expand Down Expand Up @@ -55,6 +56,11 @@ class IndexHNSW : public VecIndex {
DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

bool
HasRawData(const std::string& /*metric_type*/) const override {
return true;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_index/IndexIDMAP.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <string>

#include "knowhere/index/VecIndex.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
Expand Down Expand Up @@ -45,6 +46,11 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

bool
HasRawData(const std::string& /*metric_type*/) const override {
return true;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_index/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <string>

#include <faiss/IndexIVF.h>

Expand Down Expand Up @@ -50,6 +51,11 @@ class IVF : public VecIndex, public FaissBaseIndex {
DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

bool
HasRawData(const std::string& /*metric_type*/) const override {
return true;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_index/IndexIVFPQ.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <memory>
#include <utility>
#include <string>

#include "knowhere/index/vector_index/IndexIVF.h"

Expand All @@ -35,6 +36,11 @@ class IVFPQ : public IVF {
KNOWHERE_THROW_MSG("GetVectorById not supported yet");
}

bool
HasRawData(const std::string& /*metric_type*/) const override {
return false;
}

void
Train(const DatasetPtr&, const Config&) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_index/IndexIVFSQ.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <memory>
#include <utility>
#include <string>

#include "knowhere/index/vector_index/IndexIVF.h"

Expand All @@ -35,6 +36,11 @@ class IVFSQ : public IVF {
KNOWHERE_THROW_MSG("GetVectorById not supported yet");
}

bool
HasRawData(const std::string& /*metric_type*/) const override {
return false;
}

void
Train(const DatasetPtr&, const Config&) override;

Expand Down
6 changes: 6 additions & 0 deletions knowhere/index/vector_offset_index/IndexIVF_NM.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <mutex>
#include <utility>
#include <vector>
#include <string>

#include <faiss/IndexIVF.h>

Expand Down Expand Up @@ -51,6 +52,11 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

bool
HasRawData(const std::string& /*metric_type*/) const override {
return true;
}

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
5 changes: 5 additions & 0 deletions unittest/AsyncIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ class AsyncIndex : public VecIndex {
return index_->Load(index_binary);
}

bool
HasRawData(const std::string& metric_type) const override {
return false;
}

void
BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override {
index_->BuildAll(dataset_ptr, config);
Expand Down
1 change: 1 addition & 0 deletions unittest/test_annoy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ TEST_P(AnnoyTest, annoy_basic) {
ASSERT_EQ(index_->Dim(), dim);
ASSERT_GT(index_->Size(), 0);

ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_)));
auto result = index_->GetVectorById(id_dataset, conf_);
AssertVec(result, base_dataset, id_dataset, nq, dim);

Expand Down
1 change: 1 addition & 0 deletions unittest/test_binaryidmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
EXPECT_EQ(index_->Dim(), dim);
ASSERT_GT(index_->Size(), 0);

ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_)));
auto result = index_->GetVectorById(id_dataset, conf_);
AssertBinVec(result, base_dataset, id_dataset, nq, dim);

Expand Down
1 change: 1 addition & 0 deletions unittest/test_binaryivf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
EXPECT_EQ(index_->Dim(), dim);
ASSERT_GT(index_->Size(), 0);

ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_)));
auto result = index_->GetVectorById(id_dataset, conf_);
AssertBinVec(result, base_dataset, id_dataset, nq, dim);

Expand Down
13 changes: 13 additions & 0 deletions unittest/test_diskann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include <gtest/gtest-param-test.h>
#include <gtest/gtest.h>

#include <cmath>
Expand Down Expand Up @@ -506,6 +507,18 @@ TEST_P(DiskANNTest, knn_search_test) {
CheckDistanceError(raw_data_, query_data_, result, metric_, num_queries_, dim_, kK, num_rows_, is_large_dim_);
}

TEST_P(DiskANNTest, get_vector_by_id) {
knowhere::Config cfg;
cfg.clear();
knowhere::DiskANNPrepareConfig::Set(cfg, prep_conf);
EXPECT_TRUE(diskann->Prepare(cfg));
cfg.clear();
knowhere::DiskANNQueryConfig::Set(cfg, query_conf);

ASSERT_FALSE(diskann->HasRawData(metric_));
ASSERT_ANY_THROW(diskann->GetVectorById(nullptr, cfg));
}

TEST_P(DiskANNTest, knn_search_with_accelerate_build_test) {
if (is_large_dim_) {
GTEST_SKIP() << "Skip build accelerate test for large dim.";
Expand Down
1 change: 1 addition & 0 deletions unittest/test_hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ TEST_P(HNSWTest, HNSW_basic) {

index_->Load(bs);

ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_)));
auto result = index_->GetVectorById(id_dataset, conf_);
AssertVec(result, base_dataset, id_dataset, nq, dim);

Expand Down
1 change: 1 addition & 0 deletions unittest/test_idmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ TEST_P(IDMAPTest, idmap_basic) {
EXPECT_EQ(index_->Dim(), dim);
ASSERT_GT(index_->Size(), 0);

ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_)));
auto result = index_->GetVectorById(id_dataset, conf_);
AssertVec(result, base_dataset, id_dataset, nq, dim);

Expand Down
1 change: 1 addition & 0 deletions unittest/test_ivf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ TEST_P(IVFTest, ivf_basic) {
EXPECT_EQ(index_->Dim(), dim);
ASSERT_GT(index_->Size(), 0);
if (index_mode_ == knowhere::IndexMode::MODE_CPU) {
ASSERT_FALSE(index_->HasRawData(knowhere::GetMetaMetricType(conf_)));
ASSERT_ANY_THROW(index_->GetVectorById(id_dataset, conf_));
}

Expand Down
1 change: 1 addition & 0 deletions unittest/test_ivf_nm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ TEST_P(IVFNMTest, ivfnm_basic) {
LoadRawData(index_, base_dataset, conf_);

if (index_mode_ == knowhere::IndexMode::MODE_CPU) {
ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_)));
auto result = index_->GetVectorById(id_dataset, conf_);
AssertVec(result, base_dataset, id_dataset, nq, dim);

Expand Down

0 comments on commit b9bdbfc

Please sign in to comment.