diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 69a2bc806..7d872d60f 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -133,4 +133,6 @@ public class KNNConstants { // Please refer this github issue for more details for choosing this value: // https://github.com/opensearch-project/k-NN/issues/1049#issuecomment-1694741092 public static int MAX_DISTANCE_COMPUTATIONS = 2048000; + + public static final Float DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO = 0.95f; } diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index cd32ac4f3..db8084864 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -118,7 +119,13 @@ private static Query getFloatVectorSimilarityQuery( final float resultSimilarity, final Query filterQuery ) { - return new FloatVectorSimilarityQuery(fieldName, floatVector, resultSimilarity, filterQuery); + return new FloatVectorSimilarityQuery( + fieldName, + floatVector, + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity, + resultSimilarity, + filterQuery + ); } /** @@ -131,6 +138,12 @@ private static Query getByteVectorSimilarityQuery( final float resultSimilarity, final Query filterQuery ) { - return new ByteVectorSimilarityQuery(fieldName, byteVector, resultSimilarity, filterQuery); + return new ByteVectorSimilarityQuery( + fieldName, + byteVector, + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity, + resultSimilarity, + filterQuery + ); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 1922e5a08..134a5ab76 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -52,6 +52,7 @@ import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; @@ -394,7 +395,12 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2))); + float resultSimilarity = KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2); + + assertTrue(query.toString().contains("resultSimilarity=" + resultSimilarity)); + assertTrue( + query.toString().contains("traversalSimilarity=" + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity) + ); } public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() {