diff --git a/CHANGELOG.md b/CHANGELOG.md index 59dd618d8..feb1411e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.11...2.x) ### Features +* Add parent join support for lucene knn [#1182](https://github.com/opensearch-project/k-NN/pull/1182) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 59732cca0..184067c35 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -15,6 +15,14 @@ public class KNNConstants { public static final String NAME = "name"; public static final String PARAMETERS = "parameters"; public static final String METHOD_HNSW = "hnsw"; + public static final String TYPE = "type"; + public static final String TYPE_NESTED = "nested"; + public static final String PATH = "path"; + public static final String QUERY = "query"; + public static final String KNN = "knn"; + public static final String VECTOR = "vector"; + public static final String K = "k"; + public static final String TYPE_KNN_VECTOR = "knn_vector"; public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search"; public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; public static final String METHOD_PARAMETER_M = "m"; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index b05098f28..c073450af 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -14,6 +14,9 @@ import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; @@ -86,10 +89,12 @@ public static Query create(CreateQueryRequest createQueryRequest) { return new KNNQuery(fieldName, vector, k, indexName); } + log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter(); if (VectorDataType.BYTE == vectorDataType) { - return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery); + return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); } else if (VectorDataType.FLOAT == vectorDataType) { - return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery); + return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter); } else { throw new IllegalArgumentException( String.format( @@ -102,38 +107,40 @@ public static Query create(CreateQueryRequest createQueryRequest) { } } - private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) { - if (filterQuery != null) { - log.debug( - String.format( - Locale.ROOT, - "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", - indexName, - fieldName, - k - ) - ); + /** + * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery} + * which will dedupe search result per parent so that we can get k parent results at the end. + */ + private static Query getKnnByteVectorQuery( + final String fieldName, + final byte[] byteVector, + final int k, + final Query filterQuery, + final BitSetProducer parentFilter + ) { + if (parentFilter == null) { return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery); + } else { + return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter); } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnByteVectorQuery(fieldName, byteVector, k); } - private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) { - if (filterQuery != null) { - log.debug( - String.format( - Locale.ROOT, - "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", - indexName, - fieldName, - k - ) - ); + /** + * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery} + * which will dedupe search result per parent so that we can get k parent results at the end. + */ + private static Query getKnnFloatVectorQuery( + final String fieldName, + final float[] floatVector, + final int k, + final Query filterQuery, + final BitSetProducer parentFilter + ) { + if (parentFilter == null) { return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery); + } else { + return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter); } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnFloatVectorQuery(fieldName, floatVector, k); } private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { @@ -181,6 +188,8 @@ static class CreateQueryRequest { @Getter private int k; // can be null in cases filter not passed with the knn query + @Getter + private BitSetProducer parentFilter; private QueryBuilder filter; // can be null in cases filter not passed with the knn query private QueryShardContext context; diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java new file mode 100644 index 000000000..dce006b50 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -0,0 +1,202 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.SneakyThrows; +import org.apache.http.util.EntityUtils; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.K; +import static org.opensearch.knn.common.KNNConstants.KNN; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.PATH; +import static org.opensearch.knn.common.KNNConstants.QUERY; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED; +import static org.opensearch.knn.common.KNNConstants.VECTOR; + +public class NestedSearchIT extends KNNRestTestCase { + private static final String INDEX_NAME = "test-index-nested-search"; + private static final String FIELD_NAME_NESTED = "test-nested"; + private static final String FIELD_NAME_VECTOR = "test-vector"; + private static final String PROPERTIES_FIELD = "properties"; + private static final int EF_CONSTRUCTION = 128; + private static final int M = 16; + private static final SpaceType SPACE_TYPE = SpaceType.L2; + + @After + @SneakyThrows + public final void cleanUp() { + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() { + createKnnIndex(2, KNNEngine.LUCENE.getName()); + + String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f }) + .build(); + addNestedKnnDoc(INDEX_NAME, "1", doc1); + + String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f }) + .build(); + addNestedKnnDoc(INDEX_NAME, "2", doc2); + + Float[] queryVector = { 1f, 1f }; + Response response = queryNestedField(INDEX_NAME, 2, queryVector); + + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + EntityUtils.toString(response.getEntity()) + ).map().get("hits")).get("hits"); + assertEquals(2, hits.size()); + } + + /** + * { + * "properties": { + * "test-nested": { + * "type": "nested", + * "properties": { + * "test-vector": { + * "type": "knn_vector", + * "dimension": 3, + * "method": { + * "name": "hnsw", + * "space_type": "l2", + * "engine": "lucene", + * "parameters": { + * "ef_construction": 128, + * "m": 24 + * } + * } + * } + * } + * } + * } + * } + */ + private void createKnnIndex(final int dimension, final String engine) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_NESTED) + .field(TYPE, TYPE_NESTED) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_VECTOR) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SPACE_TYPE) + .field(KNN_ENGINE, engine) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, M) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + createKnnIndex(INDEX_NAME, mapping); + } + + @SneakyThrows + private void ingestTestData() { + String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f }) + .build(); + addNestedKnnDoc(INDEX_NAME, "1", doc1); + + String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f }) + .build(); + addNestedKnnDoc(INDEX_NAME, "2", doc2); + } + + private void addNestedKnnDoc(final String index, final String docId, final String document) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + + request.setJsonEntity(document); + client().performRequest(request); + + request = new Request("POST", "/" + index + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); + builder.field(VECTOR, vector); + builder.field(K, k); + builder.endObject().endObject().endObject().endObject().endObject().endObject(); + + Request request = new Request("POST", "/" + index + "/_search"); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + return response; + } + + private static class NestedKnnDocBuilder { + private XContentBuilder builder; + + public NestedKnnDocBuilder(final String fieldName) throws IOException { + builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName); + } + + public static NestedKnnDocBuilder create(final String fieldName) throws IOException { + return new NestedKnnDocBuilder(fieldName); + } + + public NestedKnnDocBuilder add(final String fieldName, final Object[]... vectors) throws IOException { + for (Object[] vector : vectors) { + builder.startObject(); + builder.field(fieldName, vector); + builder.endObject(); + } + return this; + } + + public String build() throws IOException { + builder.endArray().endObject(); + return builder.toString(); + } + + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 4dccfd087..a6b915a85 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -9,12 +9,16 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.mockito.Mockito; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.util.Arrays; @@ -33,6 +37,7 @@ public class KNNQueryFactoryTests extends KNNTestCase { private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE)); private final int testQueryDimension = 17; private final float[] testQueryVector = new float[testQueryDimension]; + private final byte[] testByteQueryVector = new byte[testQueryDimension]; private final String testIndexName = "test-index"; private final String testFieldName = "test-field"; private final int testK = 10; @@ -69,7 +74,7 @@ public void testCreateLuceneDefaultQuery() { testK, DEFAULT_VECTOR_DATA_TYPE_FIELD ); - assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); + assertEquals(KnnFloatVectorQuery.class, query.getClass()); } } @@ -92,7 +97,7 @@ public void testCreateLuceneQueryWithFilter() { .filter(FILTER_QUERY_BUILDER) .build(); Query query = KNNQueryFactory.create(createQueryRequest); - assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); + assertEquals(KnnFloatVectorQuery.class, query.getClass()); } } @@ -120,4 +125,35 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { assertEquals(testK, ((KNNQuery) query).getK()); assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery()); } + + public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() { + validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, DiversifyingChildrenByteKnnVectorQuery.class); + validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class); + } + + private void validateDiversifyingQueryWithParentFilter(final VectorDataType type, final Class expectedQueryClass) { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .byteVector(testByteQueryVector) + .vectorDataType(type) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + assertEquals(expectedQueryClass, query.getClass()); + } + } }