-
Notifications
You must be signed in to change notification settings - Fork 132
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle multi-vector in exact search scenario
Signed-off-by: Heemin Kim <[email protected]>
- Loading branch information
Showing
12 changed files
with
519 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
src/main/java/org/opensearch/knn/index/query/filtered/FilteredKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene | ||
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 | ||
* | ||
* The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray. | ||
*/ | ||
public abstract class FilteredKNNIterator<T extends KNNQueryVector> { | ||
// Array of doc ids to iterate | ||
protected final int[] filterIdsArray; | ||
protected float currentScore = Float.NEGATIVE_INFINITY; | ||
protected final T queryVector; | ||
protected final BinaryDocValues values; | ||
protected final SpaceType spaceType; | ||
protected int currentPos = 0; | ||
|
||
public FilteredKNNIterator(final int[] filterIdsArray, final T queryVector, final BinaryDocValues values, final SpaceType spaceType) { | ||
this.filterIdsArray = filterIdsArray; | ||
this.queryVector = queryVector; | ||
this.values = values; | ||
this.spaceType = spaceType; | ||
} | ||
|
||
/** | ||
* Advance to the next doc and update score | ||
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs | ||
* | ||
* @return next doc id | ||
*/ | ||
abstract public int nextDoc() throws IOException; | ||
|
||
/** | ||
* Return a score of current doc | ||
* | ||
* @return current score | ||
*/ | ||
public float score() { | ||
return currentScore; | ||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
src/main/java/org/opensearch/knn/index/query/filtered/KNNFloatQueryVector.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import lombok.NonNull; | ||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.apache.lucene.util.BytesRef; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.codec.util.KNNVectorSerializer; | ||
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.IOException; | ||
|
||
/** | ||
* Implementation of KNNQueryVector with float data type | ||
*/ | ||
public class KNNFloatQueryVector implements KNNQueryVector { | ||
private final float[] queryVector; | ||
|
||
public KNNFloatQueryVector(final float[] queryVector) { | ||
this.queryVector = queryVector; | ||
} | ||
|
||
// Calculate similarity between queryVector and values in current position using given space type | ||
public float score(@NonNull final BinaryDocValues values, @NonNull final SpaceType spaceType) throws IOException { | ||
final BytesRef value = values.binaryValue(); | ||
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); | ||
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); | ||
final float[] vector = vectorSerializer.byteToFloatArray(byteStream); | ||
// Calculates a similarity score between the two vectors with a specified function. Higher similarity | ||
// scores correspond to closer vectors. | ||
return spaceType.getVectorSimilarityFunction().compare(queryVector, vector); | ||
} | ||
} |
26 changes: 26 additions & 0 deletions
26
src/main/java/org/opensearch/knn/index/query/filtered/KNNQueryVector.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Wrapper interface on knn query vector to hide its type(float, byte) and provide score for given doc value and space type | ||
*/ | ||
public interface KNNQueryVector { | ||
/** | ||
* Return score of values using the spaceType | ||
* | ||
* @param values doc value | ||
* @param spaceType space type to calculate score | ||
* @return score of the doc value | ||
* @throws IOException | ||
*/ | ||
float score(final BinaryDocValues values, final SpaceType spaceType) throws IOException; | ||
} |
55 changes: 55 additions & 0 deletions
55
src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.util.BitSet; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc | ||
* of which ID is set in parentBitSet and only return best child doc with the highest score. | ||
*/ | ||
public class NestedFilteredKNNIterator<T extends KNNQueryVector> extends FilteredKNNIterator<T> { | ||
private final BitSet parentBitSet; | ||
private int currentParent = -1; | ||
private int bestChild = -1; | ||
|
||
public NestedFilteredKNNIterator( | ||
final int[] filterIdsArray, | ||
final T queryVector, | ||
final BinaryDocValues values, | ||
final SpaceType spaceType, | ||
final BitSet parentBitSet | ||
) { | ||
super(filterIdsArray, queryVector, values, spaceType); | ||
this.parentBitSet = parentBitSet; | ||
} | ||
|
||
@Override | ||
public int nextDoc() throws IOException { | ||
if (currentPos >= filterIdsArray.length) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
currentScore = Float.NEGATIVE_INFINITY; | ||
currentParent = parentBitSet.nextSetBit(filterIdsArray[currentPos]); | ||
while (currentPos < filterIdsArray.length && filterIdsArray[currentPos] < currentParent) { | ||
int currentChild = filterIdsArray[currentPos]; | ||
values.advance(currentChild); | ||
final float score = queryVector.score(values, spaceType); | ||
if (score > currentScore) { | ||
bestChild = currentChild; | ||
currentScore = score; | ||
} | ||
currentPos++; | ||
} | ||
|
||
return bestChild; | ||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
src/main/java/org/opensearch/knn/index/query/filtered/PlainFilteredKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.index.BinaryDocValues; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Basic implementation of FilteredKNNIterator which iterate all doc IDs in filterIdsArray | ||
*/ | ||
public class PlainFilteredKNNIterator<T extends KNNQueryVector> extends FilteredKNNIterator<T> { | ||
private int currentDoc = -1; | ||
|
||
public PlainFilteredKNNIterator( | ||
final int[] filterIdsArray, | ||
final T queryVector, | ||
final BinaryDocValues values, | ||
final SpaceType spaceType | ||
) { | ||
super(filterIdsArray, queryVector, values, spaceType); | ||
} | ||
|
||
@Override | ||
public int nextDoc() throws IOException { | ||
if (currentPos >= filterIdsArray.length) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
currentDoc = values.advance(filterIdsArray[currentPos++]); | ||
currentScore = queryVector.score(values, spaceType); | ||
return currentDoc; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.common; | ||
|
||
/**
 * Holder for string constants shared across the plugin. Not meant to be instantiated or extended.
 */
public final class Constants {
    /** Name of the "filter" field. */
    public static final String FIELD_FILTER = "filter";
    /** Name of the "term" field. */
    public static final String FIELD_TERM = "term";

    // Constant holder only: prevent instantiation.
    private Constants() {}
}
Oops, something went wrong.