From 089b360cf6098f5dc0a3729574a004a79a73b664 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Fri, 12 Apr 2024 09:51:13 -0700 Subject: [PATCH] Rename radial search parameters score and distance to min_score and max_distance (#1609) * Rename radial search parameters score and distance to min_score and max_distance Signed-off-by: Junqiu Lei --- .../knn/index/query/KNNQueryBuilder.java | 108 +++++++------- .../org/opensearch/knn/index/FaissIT.java | 4 +- .../opensearch/knn/index/LuceneEngineIT.java | 4 +- .../knn/index/query/KNNQueryBuilderTests.java | 134 ++++++++++++------ 4 files changed, 148 insertions(+), 102 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 78ddb532d..7d3667ac0 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -51,8 +51,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); - public static final ParseField DISTANCE_FIELD = new ParseField("distance"); - public static final ParseField SCORE_FIELD = new ParseField("score"); + public static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance"); + public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); public static final int K_MAX = 10000; /** * The name for the knn query @@ -64,17 +64,17 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; - private Float distance = null; - private Float score = null; + private Float max_distance = null; + private Float min_score = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; /** - * Constructs a new query with the given field name and vector + * Constructs a new query with the given field name and vector * * @param fieldName Name of the field * @param vector Array of floating points - */ + */ public KNNQueryBuilder(String fieldName, float[] vector) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); @@ -98,7 +98,7 @@ public KNNQueryBuilder k(Integer k) { if (k == null) { throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); } - validateSingleQueryType(k, distance, score); + validateSingleQueryType(k, max_distance, min_score); if (k <= 0 || k > K_MAX) { throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); } @@ -107,33 +107,33 @@ public KNNQueryBuilder k(Integer k) { } /** - * Builder method for distance + * Builder method for max_distance * - * @param distance the distance threshold for the nearest neighbours + * @param max_distance the max_distance threshold for the nearest neighbours */ - public KNNQueryBuilder distance(Float distance) { - if (distance == null) { - throw new IllegalArgumentException("[" + NAME + "] requires distance to be set"); + public KNNQueryBuilder maxDistance(Float max_distance) { + if (max_distance == null) { + throw new IllegalArgumentException("[" + NAME + "] requires max_distance to be set"); } - validateSingleQueryType(k, distance, score); - this.distance = distance; + validateSingleQueryType(k, max_distance, min_score); + this.max_distance = max_distance; return this; } /** - * Builder method for score + * Builder method for min_score * - * @param score the score threshold for the nearest neighbours + * @param min_score the min_score threshold for the nearest neighbours */ - public KNNQueryBuilder score(Float score) { - if (score == null) { - throw new IllegalArgumentException("[" + NAME + "] requires score to be set"); + public KNNQueryBuilder minScore(Float min_score) { + if (min_score == null) { + throw new IllegalArgumentException("[" + NAME + "] requires min_score to be set"); } - validateSingleQueryType(k, distance, score); - if (score <= 0) { - throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0"); + validateSingleQueryType(k, max_distance, min_score); + if (min_score <= 0) { + throw new IllegalArgumentException("[" + NAME + "] requires min_score greater than 0"); } - this.score = score; + this.min_score = min_score; return this; } @@ -180,8 +180,8 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.k = k; this.filter = filter; this.ignoreUnmapped = false; - this.distance = null; - this.score = null; + this.max_distance = null; + this.min_score = null; } public static void initialize(ModelDao modelDao) { @@ -217,10 +217,10 @@ public KNNQueryBuilder(StreamInput in) throws IOException { ignoreUnmapped = in.readOptionalBoolean(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - distance = in.readOptionalFloat(); + max_distance = in.readOptionalFloat(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - score = in.readOptionalFloat(); + min_score = in.readOptionalFloat(); } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); @@ -232,8 +232,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; Integer k = null; - Float distance = null; - Float score = null; + Float max_distance = null; + Float min_score = null; QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -262,10 +262,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); - } else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); - } else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + max_distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + min_score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else { throw new ParsingException( parser.getTokenLocation(), @@ -295,7 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - validateSingleQueryType(k, distance, score); + validateSingleQueryType(k, max_distance, min_score); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .ignoreUnmapped(ignoreUnmapped) @@ -304,10 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (k != null) { knnQueryBuilder.k(k); - } else if (distance != null) { - knnQueryBuilder.distance(distance); - } else if (score != null) { - knnQueryBuilder.score(score); + } else if (max_distance != null) { + knnQueryBuilder.maxDistance(max_distance); + } else if (min_score != null) { + knnQueryBuilder.minScore(min_score); } return knnQueryBuilder; @@ -323,10 +323,10 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(ignoreUnmapped); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(distance); + out.writeOptionalFloat(max_distance); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(score); + out.writeOptionalFloat(min_score); } } @@ -348,12 +348,12 @@ public int getK() { return this.k; } - public float getDistance() { - return this.distance; + public float getMaxDistance() { + return this.max_distance; } - public float getScore() { - return this.score; + public float getMinScore() { + return this.min_score; } public QueryBuilder getFilter() { @@ -384,14 +384,14 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (filter != null) { builder.field(FILTER_FIELD.getPreferredName(), filter); } - if (distance != null) { - builder.field(DISTANCE_FIELD.getPreferredName(), distance); + if (max_distance != null) { + builder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance); } if (ignoreUnmapped) { builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } - if (score != null) { - builder.field(SCORE_FIELD.getPreferredName(), score); + if (min_score != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), min_score); } printBoostAndQueryName(builder); builder.endObject(); @@ -435,18 +435,18 @@ protected Query doToQuery(QueryShardContext context) { // Currently, k-NN supports distance and score types radial search // We need transform distance/score to right type of engine required radius. Float radius = null; - if (this.distance != null) { - if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + if (this.max_distance != null) { + if (this.max_distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType); } - radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType); + radius = knnEngine.distanceToRadialThreshold(this.max_distance, spaceType); } - if (this.score != null) { - if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + if (this.min_score != null) { + if (this.min_score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType); } - radius = knnEngine.scoreToRadialThreshold(this.score, spaceType); + radius = knnEngine.scoreToRadialThreshold(this.min_score, spaceType); } if (fieldDimension != vector.length) { diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 16eb1a4c3..bcefeb7f4 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -1702,9 +1702,9 @@ private void validateRadiusSearchResults( queryBuilder.startObject(fieldName); queryBuilder.field("vector", queryVector); if (distanceThreshold != null) { - queryBuilder.field("distance", distanceThreshold); + queryBuilder.field("max_distance", distanceThreshold); } else if (scoreThreshold != null) { - queryBuilder.field("score", scoreThreshold); + queryBuilder.field("min_score", scoreThreshold); } else { throw new IllegalArgumentException("Invalid threshold"); } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index f721606e1..ab55741d3 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -629,9 +629,9 @@ private void validateRadiusSearchResults( builder.startObject(FIELD_NAME); builder.field("vector", searchVectors[i]); if (distanceThreshold != null) { - builder.field("distance", distanceThreshold); + builder.field("max_distance", distanceThreshold); } else if (scoreThreshold != null) { - builder.field("score", scoreThreshold); + builder.field("min_score", scoreThreshold); } else { throw new IllegalArgumentException("Either distance or score must be provided"); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 8998dce69..1922e5a08 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -60,8 +60,8 @@ public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; - private static final Float DISTANCE = 1.0f; - private static final Float SCORE = 0.5f; + private static final Float MAX_DISTANCE = 1.0f; + private static final Float MIN_SCORE = 0.5f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -89,25 +89,25 @@ public void testInvalidDistance() { /** * null distance */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).distance(null)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(null)); } public void testInvalidScore() { float[] queryVector = { 1.0f, 1.0f }; /** - * null score + * null min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(null)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(null)); /** - * negative score + * negative min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(-1.0f)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(-1.0f)); /** - * score = 0 + * min_score = 0 */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(0.0f)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(0.0f)); } public void testEmptyVector() { @@ -127,13 +127,13 @@ public void testEmptyVector() { * null query vector with distance */ float[] queryVector2 = null; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector2).distance(DISTANCE)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector2).maxDistance(MAX_DISTANCE)); /** * empty query vector with distance */ float[] queryVector3 = {}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).distance(DISTANCE)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).maxDistance(MAX_DISTANCE)); } public void testFromXContent() throws Exception { @@ -154,12 +154,12 @@ public void testFromXContent() throws Exception { public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getDistance()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -170,12 +170,12 @@ public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSuccee public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MAX_DISTANCE); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -213,12 +213,12 @@ public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_ knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getDistance()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); builder.endObject(); builder.endObject(); @@ -235,12 +235,12 @@ public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_the knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); builder.endObject(); builder.endObject(); @@ -294,7 +294,7 @@ public void testFromXContent_whenDoRadiusSearch_whenInputInvalidQueryVectorType_ builder.startObject(); builder.startObject(FIELD_NAME); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), DISTANCE); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -382,7 +382,7 @@ public void testDoToQuery_Normal() throws Exception { public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -394,12 +394,12 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(DISTANCE, SpaceType.L2))); + assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2))); } public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -417,7 +417,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -441,7 +441,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -463,7 +463,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(score); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -487,7 +487,53 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(score); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + + public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(negativeDistance, query.getRadius(), 0); + } + + public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -527,7 +573,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -545,7 +591,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -628,7 +674,7 @@ public void testDoToQuery_FromModel() { public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -653,14 +699,14 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - assertEquals(knnQueryBuilder.getDistance(), query.getRadius(), 0); + assertEquals(knnQueryBuilder.getMaxDistance(), query.getRadius(), 0); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -685,7 +731,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - assertEquals(1 / knnQueryBuilder.getScore() - 1, query.getRadius(), 0); + assertEquals(1 / knnQueryBuilder.getMinScore() - 1, query.getRadius(), 0); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } @@ -764,12 +810,12 @@ public void testSerialization() throws Exception { assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); // For distance threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, DISTANCE, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.empty(), null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, MAX_DISTANCE, null); // For score threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, SCORE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, SCORE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MIN_SCORE); } private void assertSerialization( @@ -801,9 +847,9 @@ private void assertSerialization( if (k != null) { assertEquals(k.intValue(), deserializedKnnQueryBuilder.getK()); } else if (distance != null) { - assertEquals(distance.floatValue(), deserializedKnnQueryBuilder.getDistance(), 0.0f); + assertEquals(distance.floatValue(), deserializedKnnQueryBuilder.getMaxDistance(), 0.0f); } else { - assertEquals(score.floatValue(), deserializedKnnQueryBuilder.getScore(), 0.0f); + assertEquals(score.floatValue(), deserializedKnnQueryBuilder.getMinScore(), 0.0f); } if (queryBuilderOptional.isPresent()) { assertNotNull(deserializedKnnQueryBuilder.getFilter()); @@ -823,12 +869,12 @@ private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBu : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k); } else if (distance != null) { knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance); + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance); } else if (score != null) { knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score); + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score); } else { throw new IllegalArgumentException("Either k or distance must be provided"); } @@ -857,7 +903,7 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { SpaceType.L2, new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()) ); - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(MAX_DISTANCE); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy");