From 28076defad5a2f2ca1cb4b55cba309ecca27427d Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Mon, 22 Apr 2024 16:02:24 -0700 Subject: [PATCH] Support filter and nested field in faiss engine radial search Signed-off-by: Junqiu Lei --- CHANGELOG.md | 2 +- jni/cmake/init-faiss.cmake | 3 +- jni/include/faiss_wrapper.h | 20 ++- .../org_opensearch_knn_jni_FaissService.h | 14 +- ...patch-to-support-range-search-params.patch | 53 ++++++ jni/src/faiss_wrapper.cpp | 72 +++++++- .../org_opensearch_knn_jni_FaissService.cpp | 19 ++- jni/tests/faiss_wrapper_test.cpp | 156 ++++++++++++++++- .../opensearch/knn/common/KNNConstants.java | 1 + .../knn/index/query/KNNQueryBuilder.java | 92 +++++----- .../opensearch/knn/index/query/KNNWeight.java | 5 +- .../org/opensearch/knn/jni/FaissService.java | 31 +++- .../org/opensearch/knn/jni/JNIService.java | 21 ++- .../org/opensearch/knn/index/FaissIT.java | 157 +++++++++++------- .../opensearch/knn/index/NestedSearchIT.java | 75 ++++++++- .../knn/index/query/KNNWeightTests.java | 9 +- 16 files changed, 597 insertions(+), 133 deletions(-) create mode 100644 jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch diff --git a/CHANGELOG.md b/CHANGELOG.md index b4b6988ff..4753eb962 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features -* Support radial search in k-NN plugin [#1617](https://github.com/opensearch-project/k-NN/pull/1617) +* Support radial search in k-NN plugin [#814](https://github.com/opensearch-project/k-NN/issues/814) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/jni/cmake/init-faiss.cmake 
b/jni/cmake/init-faiss.cmake index 44dd4442a..4cb30db50 100644 --- a/jni/cmake/init-faiss.cmake +++ b/jni/cmake/init-faiss.cmake @@ -13,13 +13,14 @@ if (NOT EXISTS ${FAISS_REPO_DIR}) endif () # Check if patch exist, this is to skip git apply during CI build. See CI.yml with ubuntu. -find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH) +find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch 0003-Custom-patch-to-support-range-search-params.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH) # If it exists, apply patches if (EXISTS ${PATCH_FILE}) message(STATUS "Applying custom patches.") execute_process(COMMAND git am --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) execute_process(COMMAND git am --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) + execute_process(COMMAND git am --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) if(RESULT_CODE) message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}") endif() diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index da67c0f59..958eca8ac 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -74,6 +74,22 @@ namespace knn_jni { jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + 
/* + * Perform a range search with filter against the index located in memory at indexPointerJ. + * + * @param indexPointerJ - pointer to the index + * @param queryVectorJ - the query vector + * @param radiusJ - the radius for the range search + * @param maxResultWindowJ - the maximum number of results to return + * @param filterIdsJ - the filter ids + * @param filterIdsTypeJ - the filter ids type + * @param parentIdsJ - the parent ids + * + * @return an array of RangeQueryResults + */ + jobjectArray RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + /* * Perform a range search against the index located in memory at indexPointerJ. * @@ -81,10 +97,12 @@ namespace knn_jni { * @param queryVectorJ - the query vector * @param radiusJ - the radius for the range search * @param maxResultsWindowJ - the maximum number of results to return + * @param parentIdsJ - the parent ids + * + * @return an array of RangeQueryResults */ jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultsWindowJ); + jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ); } } diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 3715730ab..e16677db7 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -124,11 +124,19 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors /* * Class: org_opensearch_knn_jni_FaissService -* Method: rangeSearchIndex -* Signature: (J[F[F)J +* Method: rangeSearchIndexWithFilter +* Signature: (J[FFI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ +JNIEXPORT jobjectArray JNICALL 
Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jlongArray, jint, jintArray); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: rangeSearchIndex + * Signature: (J[FFI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex - (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint); + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jintArray); #ifdef __cplusplus } diff --git a/jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch b/jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch new file mode 100644 index 000000000..bdc202bf6 --- /dev/null +++ b/jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch @@ -0,0 +1,53 @@ +From af6770b505a32b2c4eab2036d2509dec4b137f28 Mon Sep 17 00:00:00 2001 +From: Junqiu Lei +Date: Tue, 23 Apr 2024 17:18:56 -0700 +Subject: [PATCH] Custom patch to support range search params + +Signed-off-by: Junqiu Lei +--- + faiss/IndexIDMap.cpp | 28 ++++++++++++++++++++++++---- + 1 file changed, 24 insertions(+), 4 deletions(-) + +diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp +index 3f375e7b..11f3a847 100644 +--- a/faiss/IndexIDMap.cpp ++++ b/faiss/IndexIDMap.cpp +@@ -176,11 +176,31 @@ void IndexIDMapTemplate::range_search( + RangeSearchResult* result, + const SearchParameters* params) const { + if (params) { +- SearchParameters internal_search_parameters; +- IDSelectorTranslated id_selector_translated(id_map, params->sel); +- internal_search_parameters.sel = &id_selector_translated; ++ IDSelectorTranslated this_idtrans(this->id_map, nullptr); ++ ScopedSelChange sel_change; ++ IDGrouperTranslated this_idgrptrans(this->id_map, nullptr); ++ ScopedGrpChange grp_change; ++ ++ if (params->sel) { ++ auto idtrans = dynamic_cast(params->sel); ++ ++ if (!idtrans) { ++ auto params_non_const = const_cast(params); 
++ this_idtrans.sel = params->sel; ++ sel_change.set(params_non_const, &this_idtrans); ++ } ++ } ++ ++ if (params->grp) { ++ auto idtrans = dynamic_cast(params->grp); + +- index->range_search(n, x, radius, result, &internal_search_parameters); ++ if (!idtrans) { ++ auto params_non_const = const_cast(params); ++ this_idgrptrans.grp = params->grp; ++ grp_change.set(params_non_const, &this_idgrptrans); ++ } ++ } ++ index->range_search(n, x, radius, result, params); + } else { + index->range_search(n, x, radius, result); + } +-- +2.39.0 + diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 983cfa8a9..f05f088a1 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -27,6 +27,7 @@ #include #include #include +#include // Defines type of IDSelector enum FilterIdsSelectorType{ @@ -589,7 +590,12 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) { } jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, - jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) { + jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ) { + return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, nullptr, 0, parentIdsJ); +} + +jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, + jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } @@ -605,7 +611,69 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniU // The res will be freed by ~RangeSearchResult() in FAISS // The second parameter is always true, as lims is allocated by FAISS faiss::RangeSearchResult res(1, true); - indexReader->range_search(1, 
rawQueryVector, radiusJ, &res); + + if(filterIdsJ != nullptr) { + jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ); + std::unique_ptr idSelector; + if(filterIdsTypeJ == BITMAP) { + idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray)); + } else { + faiss::idx_t* batchIndices = reinterpret_cast(filteredIdsArray); + idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices)); + } + faiss::SearchParameters *searchParameters; + faiss::SearchParametersHNSW hnswParams; + faiss::SearchParametersIVF ivfParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // Set ef_search to the value that was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16, which would otherwise be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + hnswParams.sel = idSelector.get(); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } + searchParameters = &hnswParams; + } else { + auto ivfReader = dynamic_cast(indexReader->index); + auto ivfFlatReader = dynamic_cast(indexReader->index); + if(ivfReader || ivfFlatReader) { + ivfParams.sel = idSelector.get(); + searchParameters = &ivfParams; + } + } + try { + indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters); + } catch (...) 
{ + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + throw; + } + } else { + faiss::SearchParameters *searchParameters = nullptr; + faiss::SearchParametersHNSW hnswParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader!= nullptr && parentIdsJ != nullptr) { + // Set ef_search to the value that was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16, which would otherwise be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + searchParameters = &hnswParams; + } + try { + indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters); + } catch (...) { + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT); + throw; + } + } // lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries), // lims[i] - lims[i-1] gives the number of results for the i-th query. 
With a single query we used in k-NN, diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index ab2a37e84..0aa51987d 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -194,11 +194,26 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultWindowJ) + jfloat radiusJ, jint maxResultWindowJ, + jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ); + return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, parentIdsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter(JNIEnv * env, jclass cls, + jlong indexPointerJ, + jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultWindowJ, + jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) +{ + try { + return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, + maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 07b34976f..e9316dcc2 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -628,7 +628,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { // Define query data float radius = 100000.0; - int numQueries = 2; + int numQueries = 100; std::vector> queries; for (int i = 0; i < numQueries; i++) { @@ -659,7 +659,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow))); + reinterpret_cast(&query), radius, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -684,7 +684,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ // Define query data float radius = 100000.0; - int numQueries = 2; + int numQueries = 100; std::vector> queries; for (int i = 0; i < numQueries; i++) { @@ -715,7 +715,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow))); + reinterpret_cast(&query), radius, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -728,3 +728,151 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ } } } + +TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 200; + int dim = 2; + std::vector ids = test_util::Range(numIds); + std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int 
numQueries = 100; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + int num_bits = test_util::bits2words(164); + std::vector bitmap(num_bits,0); + std::vector filterIds; + + for (int64_t i = 154; i < 163; i++) { + filterIds.push_back(i); + test_util::setBitSet(i, bitmap.data(), bitmap.size()); + } + std::unordered_set filterIdSet(filterIds.begin(), filterIds.end()); + + int maxResultWindow = 20000; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearchWithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow, + reinterpret_cast(&bitmap), 0, nullptr))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + ASSERT_TRUE(results->size() <= filterIds.size()); + for (const auto& pairPtr : *results) { + auto it = filterIdSet.find(pairPtr->first); + ASSERT_NE(it, filterIdSet.end()); + } + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} + +TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 100; + std::vector ids; + std::vector vectors; + std::vector parentIds; + int dim = 2; + for (int64_t i = 1; i < numIds + 1; i++) { + if (i % 10 == 0) { + parentIds.push_back(i); + continue; + } + ids.push_back(i); + for (int j = 0; j < dim; j++) { + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + faiss::MetricType 
metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int numQueries = 1; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(&parentIds))) + .WillRepeatedly(Return(parentIds.size())); + + int maxResultWindow = 10000; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearchWithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow, nullptr, 0, + reinterpret_cast(&parentIds)))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + // Result should be one for each group + std::set idSet; + for (const auto& pairPtr : *results) { + idSet.insert(pairPtr->first / 10); + } + ASSERT_NE(0, idSet.size()); + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 7d872d60f..8eb03f04d 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -135,4 +135,5 @@ public class KNNConstants { public static int MAX_DISTANCE_COMPUTATIONS = 2048000; public static final Float DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO = 0.95f; + public static final 
String MIN_SCORE = "min_score"; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 32b7c05e8..532b9eb78 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -65,8 +65,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; - private Float max_distance = null; - private Float min_score = null; + private Float maxDistance = null; + private Float minScore = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; @@ -99,7 +99,7 @@ public KNNQueryBuilder k(Integer k) { if (k == null) { throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); } - validateSingleQueryType(k, max_distance, min_score); + validateSingleQueryType(k, maxDistance, minScore); if (k <= 0 || k > K_MAX) { throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); } @@ -108,33 +108,33 @@ public KNNQueryBuilder k(Integer k) { } /** - * Builder method for max_distance + * Builder method for maxDistance * - * @param max_distance the max_distance threshold for the nearest neighbours + * @param maxDistance the maxDistance threshold for the nearest neighbours */ - public KNNQueryBuilder maxDistance(Float max_distance) { - if (max_distance == null) { - throw new IllegalArgumentException("[" + NAME + "] requires max_distance to be set"); + public KNNQueryBuilder maxDistance(Float maxDistance) { + if (maxDistance == null) { + throw new IllegalArgumentException("[" + NAME + "] requires maxDistance to be set"); } - validateSingleQueryType(k, max_distance, min_score); - this.max_distance = max_distance; + validateSingleQueryType(k, maxDistance, minScore); + this.maxDistance = maxDistance; return this; } /** - * Builder method for min_score + * Builder method for 
minScore * - * @param min_score the min_score threshold for the nearest neighbours + * @param minScore the minScore threshold for the nearest neighbours */ - public KNNQueryBuilder minScore(Float min_score) { - if (min_score == null) { - throw new IllegalArgumentException("[" + NAME + "] requires min_score to be set"); + public KNNQueryBuilder minScore(Float minScore) { + if (minScore == null) { + throw new IllegalArgumentException("[" + NAME + "] requires minScore to be set"); } - validateSingleQueryType(k, max_distance, min_score); - if (min_score <= 0) { - throw new IllegalArgumentException("[" + NAME + "] requires min_score greater than 0"); + validateSingleQueryType(k, maxDistance, minScore); + if (minScore <= 0) { + throw new IllegalArgumentException("[" + NAME + "] requires minScore greater than 0"); } - this.min_score = min_score; + this.minScore = minScore; return this; } @@ -181,8 +181,8 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.k = k; this.filter = filter; this.ignoreUnmapped = false; - this.max_distance = null; - this.min_score = null; + this.maxDistance = null; + this.minScore = null; } public static void initialize(ModelDao modelDao) { @@ -218,10 +218,10 @@ public KNNQueryBuilder(StreamInput in) throws IOException { ignoreUnmapped = in.readOptionalBoolean(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - max_distance = in.readOptionalFloat(); + maxDistance = in.readOptionalFloat(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - min_score = in.readOptionalFloat(); + minScore = in.readOptionalFloat(); } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); @@ -233,8 +233,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; Integer k = null; - Float max_distance = null; - Float min_score = null; 
+ Float maxDistance = null; + Float minScore = null; QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -264,9 +264,9 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - max_distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + maxDistance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - min_score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + minScore = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else { throw new ParsingException( parser.getTokenLocation(), @@ -296,7 +296,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - validateSingleQueryType(k, max_distance, min_score); + validateSingleQueryType(k, maxDistance, minScore); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .ignoreUnmapped(ignoreUnmapped) @@ -305,10 +305,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (k != null) { knnQueryBuilder.k(k); - } else if (max_distance != null) { - knnQueryBuilder.maxDistance(max_distance); - } else if (min_score != null) { - knnQueryBuilder.minScore(min_score); + } else if (maxDistance != null) { + knnQueryBuilder.maxDistance(maxDistance); + } else if (minScore != null) { + knnQueryBuilder.minScore(minScore); } return knnQueryBuilder; @@ -324,10 +324,10 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(ignoreUnmapped); } if 
(isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(max_distance); + out.writeOptionalFloat(maxDistance); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(min_score); + out.writeOptionalFloat(minScore); } } @@ -350,11 +350,11 @@ public int getK() { } public float getMaxDistance() { - return this.max_distance; + return this.maxDistance; } public float getMinScore() { - return this.min_score; + return this.minScore; } public QueryBuilder getFilter() { @@ -385,14 +385,14 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (filter != null) { builder.field(FILTER_FIELD.getPreferredName(), filter); } - if (max_distance != null) { - builder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance); + if (maxDistance != null) { + builder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance); } if (ignoreUnmapped) { builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } - if (min_score != null) { - builder.field(MIN_SCORE_FIELD.getPreferredName(), min_score); + if (minScore != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); } printBoostAndQueryName(builder); builder.endObject(); @@ -436,18 +436,18 @@ protected Query doToQuery(QueryShardContext context) { // Currently, k-NN supports distance and score types radial search // We need transform distance/score to right type of engine required radius. 
Float radius = null; - if (this.max_distance != null) { - if (this.max_distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + if (this.maxDistance != null) { + if (this.maxDistance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType); } - radius = knnEngine.distanceToRadialThreshold(this.max_distance, spaceType); + radius = knnEngine.distanceToRadialThreshold(this.maxDistance, spaceType); } - if (this.min_score != null) { - if (this.min_score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + if (this.minScore != null) { + if (this.minScore > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType); } - radius = knnEngine.scoreToRadialThreshold(this.min_score, spaceType); + radius = knnEngine.scoreToRadialThreshold(this.minScore, spaceType); } if (fieldDimension != vector.length) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 8939a569e..bac8c03d4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -294,7 +294,10 @@ private Map doANNSearch(final LeafReaderContext context, final B knnQuery.getQueryVector(), knnQuery.getRadius(), knnEngine, - knnQuery.getContext().getMaxResultWindow() + knnQuery.getContext().getMaxResultWindow(), + filterIds, + filterType.getValue(), + parentIds ); } } catch (Exception e) { diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index b59ac4bcf..53980bbb7 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -191,6 +191,28 @@ public 
static native KNNQueryResult[] queryIndexWithFilter( @Deprecated(since = "2.14.0", forRemoval = true) public static native long transferVectors(long vectorsPointer, float[][] trainingData); + /** + * Range search index with filter + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param indexMaxResultWindow maximum number of results to return + * @param filteredIds list of doc ids to include in the query result + * @param filterIdsType type of filter ids + * @param parentIds list of parent doc ids when the knn field is a nested field + * @return KNNQueryResult array of neighbors within radius + */ + public static native KNNQueryResult[] rangeSearchIndexWithFilter( + long indexPointer, + float[] queryVector, + float radius, + int indexMaxResultWindow, + long[] filteredIds, + int filterIdsType, + int[] parentIds + ); + /** * Range search index * @@ -198,7 +220,14 @@ public static native KNNQueryResult[] queryIndexWithFilter( * @param queryVector vector to be used for query * @param radius search within radius threshold * @param indexMaxResultWindow maximum number of results to return + * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of neighbors within radius */ - public static native KNNQueryResult[] rangeSearchIndex(long indexPointer, float[] queryVector, float radius, int indexMaxResultWindow); + public static native KNNQueryResult[] rangeSearchIndex( + long indexPointer, + float[] queryVector, + float radius, + int indexMaxResultWindow, + int[] parentIds + ); } diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index e846f02d1..20c418819 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -271,6 +271,9 @@ public static long transferVectors(long 
vectorsPointer, float[][] trainingData) * @param radius search within radius threshold * @param knnEngine engine to query index * @param indexMaxResultWindow maximum number of results to return + * @param filteredIds list of doc ids to include in the query result + * @param filterIdsType how to filter ids: Batch or BitMap + * @param parentIds parent ids of the vectors * @return KNNQueryResult array of neighbors within radius */ public static KNNQueryResult[] radiusQueryIndex( @@ -278,10 +281,24 @@ public static KNNQueryResult[] radiusQueryIndex( float[] queryVector, float radius, KNNEngine knnEngine, - int indexMaxResultWindow + int indexMaxResultWindow, + long[] filteredIds, + int filterIdsType, + int[] parentIds ) { if (KNNEngine.FAISS == knnEngine) { - return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, indexMaxResultWindow); + if (ArrayUtils.isNotEmpty(filteredIds)) { + return FaissService.rangeSearchIndexWithFilter( + indexPointer, + queryVector, + radius, + indexMaxResultWindow, + filteredIds, + filterIdsType, + parentIds + ); + } + return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, indexMaxResultWindow, parentIds); } throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 49f74cb7c..6e230dd86 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -23,6 +23,7 @@ import org.opensearch.client.ResponseException; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNResult; @@ -34,6 +35,7 @@ import java.io.IOException; import java.net.URL; +import java.util.ArrayList; 
import java.util.Arrays; import java.util.HashSet; import java.util.List; @@ -91,7 +93,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { String indexName = "test-index-1"; String fieldName = "test-field-1"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -108,10 +110,10 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -180,7 +182,7 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { String indexName = "test-index-1"; String fieldName = "test-field-1"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -197,10 +199,10 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - 
.field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -279,7 +281,7 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() { - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -296,10 +298,10 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, 
efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -330,7 +332,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); float distance = 300000000000f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType, null); // Delete index deleteKNNIndex(INDEX_NAME); @@ -338,7 +340,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -355,10 +357,10 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, 
efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -390,7 +392,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF float score = 0.00001f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null); // Delete index deleteKNNIndex(INDEX_NAME); @@ -398,7 +400,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.INNER_PRODUCT; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -415,10 +417,10 @@ public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMe .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -450,14 +452,14 @@ public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMe float score = 5f; - 
validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null); // Delete index deleteKNNIndex(INDEX_NAME); } @SneakyThrows - public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsHNSWPQ_thenSucceed() { + public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; String fieldName = "test-field"; String trainingIndexName = "training-index"; @@ -535,7 +537,7 @@ public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsH float distance = 300000000000f; - validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType); + validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType, null); // Delete index deleteKNNIndex(indexName); @@ -554,6 +556,32 @@ public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsH fail("Graphs are not getting evicted"); } + @SneakyThrows + public void testRadialQuery_withFilter_thenSuccess() { + setupKNNIndexForFilterQuery(); + + final float[][] searchVector = new float[][] { { 3.3f, 3.0f, 5.0f } }; + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red"); + List expectedDocIds = Arrays.asList(DOC_ID_3); + + float distance = 15f; + List> queryResult = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + searchVector, + distance, + null, + SpaceType.L2, + termQueryBuilder + ); + + assertEquals(1, queryResult.get(0).size()); + assertEquals(expectedDocIds.get(0), queryResult.get(0).get(0).getDocId()); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; @@ -670,7 +698,7 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { String indexName = 
"test-index-hnsw-sqfp16"; String fieldName = "test-field-hnsw-sqfp16"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; Random random = new Random(); SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; @@ -690,10 +718,10 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -784,7 +812,7 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { String indexName = "test-index-sqfp16"; String fieldName = "test-field-sqfp16"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; Random random = new Random(); SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; @@ -803,10 +831,10 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { .field("type", "knn_vector") .field("dimension", dimension) 
.startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -886,7 +914,7 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then String indexName = "test-index-sqfp16-clip-fp16"; String fieldName = "test-field-sqfp16"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); Random random = new Random(); List mValues = ImmutableList.of(16, 32, 64, 128); @@ -903,10 +931,10 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, 
efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -1351,7 +1379,7 @@ public void testDocUpdate() throws IOException { String fieldName = "test-field-1"; Integer dimension = 2; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; // Create an index @@ -1362,9 +1390,9 @@ public void testDocUpdate() throws IOException { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() @@ -1387,7 +1415,7 @@ public void testDocDeletion() throws IOException { String fieldName = "test-field-1"; Integer dimension = 2; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; // Create an index @@ -1398,9 +1426,9 @@ public void testDocDeletion() throws IOException { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() 
@@ -1576,9 +1604,9 @@ public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore .field("type", "knn_vector") .field("dimension", 2) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() @@ -1627,9 +1655,9 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .field("type", "knn_vector") .field("dimension", 3) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() @@ -1689,14 +1717,16 @@ private void validateGraphEviction() throws Exception { fail("Graphs are not getting evicted"); } - private void validateRadiusSearchResults( + private List> validateRadiusSearchResults( String indexName, String fieldName, float[][] queryVectors, Float distanceThreshold, Float scoreThreshold, - final SpaceType spaceType + final SpaceType spaceType, + TermQueryBuilder filterQuery ) throws IOException, ParseException { + List> queryResults = new ArrayList<>(); for (float[] queryVector : queryVectors) { XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); queryBuilder.startObject("knn"); 
@@ -1709,6 +1739,9 @@ private void validateRadiusSearchResults( } else { throw new IllegalArgumentException("Invalid threshold"); } + if (filterQuery != null) { + queryBuilder.field("filter", filterQuery); + } queryBuilder.endObject(); queryBuilder.endObject(); queryBuilder.endObject().endObject(); @@ -1727,6 +1760,8 @@ private void validateRadiusSearchResults( throw new IllegalArgumentException("Invalid space type"); } } + queryResults.add(knnResults); } + return queryResults; } } diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java index 288f06caf..20b87708f 100644 --- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -32,6 +32,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.PATH; @@ -152,7 +153,64 @@ public void testNestedSearchWithFaiss_whenDoingExactSearch_thenReturnCorrectResu updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 100)); Float[] queryVector = { 3f, 3f, 3f }; - Response response = queryNestedField(INDEX_NAME, 3, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + Response response = queryNestedField(INDEX_NAME, 3, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE, null); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(2, docIds.size()); + assertEquals("3", docIds.get(0)); + assertEquals("1", docIds.get(1)); + assertEquals(2, 
parseTotalSearchHits(entity)); + } + + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector": { + * "vector": [ + * 1, 1, 1 + * ], + * "min_score": 0.00001, + * "filter": { + * "term": { + * "parking": "true" + * } + * } + * } + * } + * } + * } + * } + * } + * + */ + @SneakyThrows + public void testNestedWithFaiss_whenFilter_whenDoRadialSearch_thenReturnCorrectResults() { + createKnnIndex(3, KNNEngine.FAISS.getName()); + + for (int i = 1; i < 4; i++) { + float value = (float) i; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addVectors( + FIELD_NAME_VECTOR, + new Float[] { value, value, value }, + new Float[] { value, value, value }, + new Float[] { value, value, value } + ) + .addTopLevelField(FIELD_NAME_PARKING, i % 2 == 1 ? FIELD_VALUE_TRUE : FIELD_VALUE_FALSE) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + Float[] queryVector = { 3f, 3f, 3f }; + Float minScore = 0.00001f; + Response response = queryNestedField(INDEX_NAME, null, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE, minScore); + String entity = EntityUtils.toString(response.getEntity()); List docIds = parseIds(entity); assertEquals(2, docIds.size()); @@ -215,22 +273,29 @@ private void createKnnIndex(final int dimension, final String engine) throws Exc } private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException { - return queryNestedField(index, k, vector, null, null); + return queryNestedField(index, k, vector, null, null, null); } private Response queryNestedField( final String index, - final int k, + final Integer k, final Object[] vector, final String filterName, - final String filterValue + final String filterValue, + final Float minScore ) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); 
builder.startObject(TYPE_NESTED); builder.field(PATH, FIELD_NAME_NESTED); builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); builder.field(VECTOR, vector); - builder.field(K, k); + if (minScore != null) { + builder.field(MIN_SCORE, minScore); + } else if (k != null) { + builder.field(K, k); + } else { + throw new IllegalArgumentException("k or minScore must be provided in the query"); + } if (filterName != null && filterValue != null) { builder.startObject(FIELD_FILTER); builder.startObject(FIELD_TERM); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index c8d9fa400..adf985b39 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -699,8 +699,9 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final float[] queryVector = new float[] { 0.1f, 0.3f }; final float radius = 0.5f; final int maxResults = 1000; - jniServiceMockedStatic.when(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt(), any(), anyInt(), any()) + ).thenReturn(getKNNQueryResults()); KNNQuery.Context context = mock(KNNQuery.Context.class); when(context.getMaxResultWindow()).thenReturn(maxResults); @@ -742,7 +743,9 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); assertNotNull(knnScorer); - jniServiceMockedStatic.verify(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())); + jniServiceMockedStatic.verify( + () -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt(), any(), anyInt(), any()) + ); 
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();