diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 180ce1b310..9d118e7375 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -15,13 +15,14 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.query.filtered.FilteredKNNIterator; +import org.opensearch.knn.index.query.filtered.KNNFloatQueryVector; +import org.opensearch.knn.index.query.filtered.NestedFilteredKNNIterator; +import org.opensearch.knn.index.query.filtered.PlainFilteredKNNIterator; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; @@ -44,7 +45,6 @@ import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.plugin.stats.KNNCounter; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.file.Path; import java.util.Arrays; @@ -306,33 +306,23 @@ private Map doANNSearch(final LeafReaderContext context, final i } private Map doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) throws IOException { - final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); - final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); - float[] queryVector = this.knnQuery.getQueryVector(); try { - final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); - final SpaceType spaceType = getSpaceType(fieldInfo); // Creating min heap and init with MAX DocID and Score as -INF. final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); - for (int filterId : filterIdsArray) { - int docId = values.advance(filterId); - final BytesRef value = values.binaryValue(); - final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - // Calculates a similarity score between the two vectors with a specified function. Higher similarity - // scores correspond to closer vectors. - float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector); - if (score > topDoc.score) { - topDoc.score = score; + FilteredKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsArray); + int docId; + while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (iterator.score() > topDoc.score) { + topDoc.score = iterator.score(); topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we // have seen till now on top. topDoc = queue.updateTop(); } } + // If scores are negative we will remove them. // This is done, because there can be negative values in the Heap as we init the heap with Score as -INF. // If filterIds < k, the some values in heap can have a negative score. @@ -352,6 +342,23 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont return Collections.emptyMap(); } + private FilteredKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) + throws IOException { + final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); + final SpaceType spaceType = getSpaceType(fieldInfo); + return knnQuery.getParentsFilter() == null + ? new PlainFilteredKNNIterator(filterIdsArray, new KNNFloatQueryVector(knnQuery.getQueryVector()), values, spaceType) + : new NestedFilteredKNNIterator( + filterIdsArray, + new KNNFloatQueryVector(knnQuery.getQueryVector()), + values, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } + private Scorer convertSearchResponseToScorer(final Map docsToScore) throws IOException { final int maxDoc = Collections.max(docsToScore.keySet()) + 1; final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredKNNIterator.java new file mode 100644 index 0000000000..2eb461aedc --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredKNNIterator.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.index.BinaryDocValues; +import org.opensearch.knn.index.SpaceType; + +import java.io.IOException; + +/** + * Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene + * https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 + * + * The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray. + */ +public abstract class FilteredKNNIterator { + // Array of doc ids to iterate + protected final int[] filterIdsArray; + protected float currentScore = Float.NEGATIVE_INFINITY; + protected final T queryVector; + protected final BinaryDocValues values; + protected final SpaceType spaceType; + protected int currentPos = 0; + + public FilteredKNNIterator(final int[] filterIdsArray, final T queryVector, final BinaryDocValues values, final SpaceType spaceType) { + this.filterIdsArray = filterIdsArray; + this.queryVector = queryVector; + this.values = values; + this.spaceType = spaceType; + } + + /** + * Advance to the next doc and update score + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next doc id + */ + abstract public int nextDoc() throws IOException; + + /** + * Return a score of current doc + * + * @return current score + */ + public float score() { + return currentScore; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/KNNFloatQueryVector.java b/src/main/java/org/opensearch/knn/index/query/filtered/KNNFloatQueryVector.java new file mode 100644 index 0000000000..58e8eb79c6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/KNNFloatQueryVector.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import lombok.NonNull; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; + +import java.io.ByteArrayInputStream; +import java.io.IOException; + +/** + * Implementation of KNNQueryVector with float data type + */ +public class KNNFloatQueryVector implements KNNQueryVector { + private final float[] queryVector; + + public KNNFloatQueryVector(final float[] queryVector) { + this.queryVector = queryVector; + } + + // Calculate similarity between queryVector and values in current position using given space type + public float score(@NonNull final BinaryDocValues values, @NonNull final SpaceType spaceType) throws IOException { + final BytesRef value = values.binaryValue(); + final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. + return spaceType.getVectorSimilarityFunction().compare(queryVector, vector); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/KNNQueryVector.java b/src/main/java/org/opensearch/knn/index/query/filtered/KNNQueryVector.java new file mode 100644 index 0000000000..595130dcff --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/KNNQueryVector.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.index.BinaryDocValues; +import org.opensearch.knn.index.SpaceType; + +import java.io.IOException; + +/** + * Wrapper interface on knn query vector to hide its type(float, byte) and provide score for given doc value and space type + */ +public interface KNNQueryVector { + /** + * Return score of values using the spaceType + * + * @param values doc value + * @param spaceType space type to calculate score + * @return score of the doc value + * @throws IOException + */ + float score(final BinaryDocValues values, final SpaceType spaceType) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredKNNIterator.java new file mode 100644 index 0000000000..3a6a979324 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredKNNIterator.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.opensearch.knn.index.SpaceType; + +import java.io.IOException; + +/** + * This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc + * of which ID is set in parentBitSet and only return best child doc with the highest score. + */ +public class NestedFilteredKNNIterator extends FilteredKNNIterator { + private final BitSet parentBitSet; + private int currentParent = -1; + private int bestChild = -1; + + public NestedFilteredKNNIterator( + final int[] filterIdsArray, + final T queryVector, + final BinaryDocValues values, + final SpaceType spaceType, + final BitSet parentBitSet + ) { + super(filterIdsArray, queryVector, values, spaceType); + this.parentBitSet = parentBitSet; + } + + @Override + public int nextDoc() throws IOException { + if (currentPos >= filterIdsArray.length) { + return DocIdSetIterator.NO_MORE_DOCS; + } + currentScore = Float.NEGATIVE_INFINITY; + currentParent = parentBitSet.nextSetBit(filterIdsArray[currentPos]); + while (currentPos < filterIdsArray.length && filterIdsArray[currentPos] < currentParent) { + int currentChild = filterIdsArray[currentPos]; + values.advance(currentChild); + final float score = queryVector.score(values, spaceType); + if (score > currentScore) { + bestChild = currentChild; + currentScore = score; + } + currentPos++; + } + + return bestChild; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/PlainFilteredKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/PlainFilteredKNNIterator.java new file mode 100644 index 0000000000..b1e3573b42 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/PlainFilteredKNNIterator.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.index.SpaceType; + +import java.io.IOException; + +/** + * Basic implementation of FilteredKNNIterator which iterate all doc IDs in filterIdsArray + */ +public class PlainFilteredKNNIterator extends FilteredKNNIterator { + private int currentDoc = -1; + + public PlainFilteredKNNIterator( + final int[] filterIdsArray, + final T queryVector, + final BinaryDocValues values, + final SpaceType spaceType + ) { + super(filterIdsArray, queryVector, values, spaceType); + } + + @Override + public int nextDoc() throws IOException { + if (currentPos >= filterIdsArray.length) { + return DocIdSetIterator.NO_MORE_DOCS; + } + currentDoc = values.advance(filterIdsArray[currentPos++]); + currentScore = queryVector.score(values, spaceType); + return currentDoc; + } +} diff --git a/src/test/java/org/opensearch/knn/common/Constants.java b/src/test/java/org/opensearch/knn/common/Constants.java new file mode 100644 index 0000000000..2580d2c9c1 --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/Constants.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +public class Constants { + public static final String FIELD_FILTER = "filter"; + public static final String FIELD_TERM = "term"; +} diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java index ccb3abb565..73fe6b72a4 100644 --- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -10,6 +10,7 @@ import org.junit.After; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; @@ -18,7 +19,10 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.List; +import static org.opensearch.knn.common.Constants.FIELD_FILTER; +import static org.opensearch.knn.common.Constants.FIELD_TERM; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.K; import static org.opensearch.knn.common.KNNConstants.KNN; @@ -39,8 +43,11 @@ public class NestedSearchIT extends KNNRestTestCase { private static final String INDEX_NAME = "test-index-nested-search"; - private static final String FIELD_NAME_NESTED = "test-nested"; - private static final String FIELD_NAME_VECTOR = "test-vector"; + private static final String FIELD_NAME_NESTED = "test_nested"; + private static final String FIELD_NAME_VECTOR = "test_vector"; + private static final String FIELD_NAME_PARKING = "parking"; + private static final String FIELD_VALUE_TRUE = "true"; + private static final String FIELD_VALUE_FALSE = "false"; private static final String PROPERTIES_FIELD = "properties"; private static final int EF_CONSTRUCTION = 128; private static final int M = 16; @@ -98,13 +105,70 @@ public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() { assertEquals(2, parseTotalSearchHits(entity)); } + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector": { + * "vector": [ + * 1, 1, 1 + * ], + * "k": 3, + * "filter": { + * "term": { + * "parking": "true" + * } + * } + * } + * } + * } + * } + * } + * } + * + */ + @SneakyThrows + public void testNestedSearchWithFaiss_whenDoingExactSearch_thenReturnCorrectResults() { + createKnnIndex(3, KNNEngine.FAISS.getName()); + + for (int i = 1; i < 4; i++) { + float value = (float) i; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addVectors( + FIELD_NAME_VECTOR, + new Float[] { value, value, value }, + new Float[] { value, value, value }, + new Float[] { value, value, value } + ) + .addTopLevelField(FIELD_NAME_PARKING, i % 2 == 1 ? FIELD_VALUE_TRUE : FIELD_VALUE_FALSE) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + + // Make it as an exact search by setting the threshold larger than size of filteredIds(6) + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 100)); + + Float[] queryVector = { 3f, 3f, 3f }; + Response response = queryNestedField(INDEX_NAME, 3, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(2, docIds.size()); + assertEquals("3", docIds.get(0)); + assertEquals("1", docIds.get(1)); + assertEquals(2, parseTotalSearchHits(entity)); + } + /** * { * "properties": { - * "test-nested": { + * "test_nested": { * "type": "nested", * "properties": { - * "test-vector": { + * "test_vector": { * "type": "knn_vector", * "dimension": 3, * "method": { @@ -152,12 +216,29 @@ private void createKnnIndex(final int dimension, final String engine) throws Exc } private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException { + return queryNestedField(index, k, vector, null, null); + } + + private Response queryNestedField( + final String index, + final int k, + final Object[] vector, + final String filterName, + final String filterValue + ) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); builder.startObject(TYPE_NESTED); builder.field(PATH, FIELD_NAME_NESTED); builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); builder.field(VECTOR, vector); builder.field(K, k); + if (filterName != null && filterValue != null) { + builder.startObject(FIELD_FILTER); + builder.startObject(FIELD_TERM); + builder.field(filterName, filterValue); + builder.endObject(); + builder.endObject(); + } builder.endObject().endObject().endObject().endObject().endObject().endObject(); Request request = new Request("POST", "/" + index + "/_search"); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 26179ecb6f..4d04b7edbf 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -597,7 +597,53 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { } @SneakyThrows - public void testANNWithParentsFilter_whenSet_thenBitSetIsPassedToJNI() { + public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { + SegmentReader reader = getMockedSegmentReader(); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + // We will have 0, 1 for filteredIds and 2 will be the parent id for both of them + final Scorer filterScorer = mock(Scorer.class); + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2)); + when(reader.maxDoc()).thenReturn(2); + + // Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result + final List vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f }); + final List byteRefs = vectors.stream() + .map(vector -> new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector))) + .collect(Collectors.toList()); + final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + when(binaryDocValues.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1)); + when(binaryDocValues.advance(anyInt())).thenReturn(0, 1); + when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); + + // Parent ID 2 in bitset is 100 which is 4 + FixedBitSet parentIds = new FixedBitSet(new long[] { 4 }, 3); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(parentFilter.getBitSet(leafReaderContext)).thenReturn(parentIds); + + final Weight filterQueryWeight = mock(Weight.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, parentFilter); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + + // Execute + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Verify + final List expectedScores = vectors.stream() + .map(vector -> SpaceType.L2.getVectorSimilarityFunction().compare(QUERY_VECTOR, vector)) + .collect(Collectors.toList()); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertEquals(1, docIdSetIterator.nextDoc()); + assertEquals(expectedScores.get(1), knnScorer.score(), 0.01f); + assertEquals(NO_MORE_DOCS, docIdSetIterator.nextDoc()); + } + + @SneakyThrows + public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { SegmentReader reader = getMockedSegmentReader(); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); when(leafReaderContext.reader()).thenReturn(reader); @@ -632,11 +678,7 @@ private SegmentReader getMockedSegmentReader() { when(reader.maxDoc()).thenReturn(1); // Prepare live docs - int liveDocId = 0; - final Bits liveDocsBits = mock(Bits.class); - when(liveDocsBits.get(liveDocId)).thenReturn(true); - when(liveDocsBits.length()).thenReturn(1); - when(reader.getLiveDocs()).thenReturn(liveDocsBits); + when(reader.getLiveDocs()).thenReturn(null); // Prepare directory final Path path = mock(Path.class); @@ -662,10 +704,14 @@ private SegmentReader getMockedSegmentReader() { final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); - // Prepare fieldInfos - final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName()); + // Prepare fieldInfo + final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); final FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.attributes()).thenReturn(attributesMap); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + when(fieldInfo.getName()).thenReturn(FIELD_NAME); + + // Prepare fieldInfos final FieldInfos fieldInfos = mock(FieldInfos.class); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(reader.getFieldInfos()).thenReturn(fieldInfos); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredFloatKNNScorerTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredFloatKNNScorerTests.java new file mode 100644 index 0000000000..906defee8a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredFloatKNNScorerTests.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class FilteredFloatKNNScorerTests extends KNNTestCase { + @SneakyThrows + public void testNextDoc_whenCalled_IterateAllDocs() { + final SpaceType spaceType = SpaceType.L2; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; + final int[] filterIds = { 1, 2, 3 }; + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 14.0f, 15.0f, 16.0f }, + new float[] { 17.0f, 18.0f, 19.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + BinaryDocValues values = mock(BinaryDocValues.class); + final List byteRefs = dataVectors.stream() + .map(vector -> new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector))) + .collect(Collectors.toList()); + when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + } + + // Execute and verify + PlainFilteredKNNIterator scorer = new PlainFilteredKNNIterator(filterIds, new KNNFloatQueryVector(queryVector), values, spaceType); + for (int i = 0; i < filterIds.length; i++) { + assertEquals(filterIds[i], scorer.nextDoc()); + assertEquals(expectedScores.get(i), (Float) scorer.score()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.nextDoc()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredFloatKNNScorerTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredFloatKNNScorerTests.java new file mode 100644 index 0000000000..e7bfff81b9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredFloatKNNScorerTests.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class NestedFilteredFloatKNNScorerTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { + final SpaceType spaceType = SpaceType.L2; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 17.0f, 18.0f, 19.0f }, + new float[] { 14.0f, 15.0f, 16.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + BinaryDocValues values = mock(BinaryDocValues.class); + final List byteRefs = dataVectors.stream() + .map(vector -> new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector))) + .collect(Collectors.toList()); + when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + } + + // Execute and verify + NestedFilteredKNNIterator scorer = new NestedFilteredKNNIterator( + filterIds, + new KNNFloatQueryVector(queryVector), + values, + spaceType, + parentBitSet + ); + assertEquals(filterIds[0], scorer.nextDoc()); + assertEquals(expectedScores.get(0), scorer.score()); + assertEquals(filterIds[2], scorer.nextDoc()); + assertEquals(expectedScores.get(2), scorer.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.nextDoc()); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index d0ece655c1..3af8b7ec8d 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -721,6 +721,16 @@ protected int parseHits(String searchResponseBody) throws IOException { return ((List) responseMap.get("hits")).size(); } + protected List parseIds(String searchResponseBody) throws IOException { + @SuppressWarnings("unchecked") + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + searchResponseBody + ).map().get("hits")).get("hits"); + + return hits.stream().map(hit -> (String) ((Map) hit).get("_id")).collect(Collectors.toList()); + } + /** * Get the total number of graphs in the cache across all nodes */