Skip to content

Commit

Permalink
Support score threshold in radial search
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 3, 2024
1 parent c369ec7 commit 7de0298
Show file tree
Hide file tree
Showing 12 changed files with 649 additions and 85 deletions.
90 changes: 76 additions & 14 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField DISTANCE_FIELD = new ParseField("distance");
public static final ParseField SCORE_FIELD = new ParseField("score");
public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -64,6 +65,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final float[] vector;
private int k = 0;
private Float distance = null;
private Float score = null;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

Expand Down Expand Up @@ -92,13 +94,14 @@ public KNNQueryBuilder(String fieldName, float[] vector) {
*
* @param k K nearest neighbours for the given vector
*/
public KNNQueryBuilder k(int k) {
public KNNQueryBuilder k(Integer k) {
if (k == null) {
throw new IllegalArgumentException("[" + NAME + "] requires k to be set");

Check warning on line 99 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L99

Added line #L99 was not covered by tests
}
validSingleQueryType(k, distance, score);
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX);
}
if (distance != null) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
this.k = k;
return this;
}
Expand All @@ -112,13 +115,28 @@ public KNNQueryBuilder distance(Float distance) {
if (distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be set");
}
if (k != 0) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validSingleQueryType(k, distance, score);
this.distance = distance;
return this;
}

/**
 * Builder method for score
 *
 * Selects score-threshold radial search for this query. The score is mutually exclusive with
 * k and distance: the call fails if either of those has already been set on this builder.
 *
 * @param score the score threshold for the nearest neighbours; must be non-null, not NaN and greater than 0
 * @return this builder, to allow call chaining
 * @throws IllegalArgumentException if score is null, NaN or not greater than 0, or if k or distance is already set
 */
public KNNQueryBuilder score(Float score) {
    if (score == null) {
        throw new IllegalArgumentException("[" + NAME + "] requires score to be set");
    }
    validSingleQueryType(k, distance, score);
    // NaN fails every ordered comparison, so "score <= 0" alone would silently accept it.
    if (score.isNaN() || score <= 0) {
        throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0");
    }
    this.score = score;
    return this;
}

/**
* Builder method for filter
*
Expand Down Expand Up @@ -163,6 +181,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.filter = filter;
this.ignoreUnmapped = false;
this.distance = null;
this.score = null;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -200,6 +219,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
distance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
score = in.readOptionalFloat();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -211,6 +233,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float distance = null;
Float score = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
Expand Down Expand Up @@ -241,6 +264,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
queryName = parser.text();
} else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -270,9 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

if ((k != null && distance != null) || (k == null && distance == null)) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validSingleQueryType(k, distance, score);

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand All @@ -281,8 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep

if (k != null) {
knnQueryBuilder.k(k);
} else {
} else if (distance != null) {
knnQueryBuilder.distance(distance);
} else if (score != null) {
knnQueryBuilder.score(score);
}

return knnQueryBuilder;
Expand All @@ -300,6 +325,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(distance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(score);
}
}

/**
Expand All @@ -324,6 +352,10 @@ public float getDistance() {
return this.distance;
}

/**
 * Returns the score threshold configured for this query.
 *
 * The backing field is a nullable {@code Float}; unboxing it here throws a
 * {@link NullPointerException} when no score has been set on the builder.
 *
 * @return the score threshold as a primitive float
 */
public float getScore() {
    return score.floatValue();
}

public QueryBuilder getFilter() {
return this.filter;
}
Expand Down Expand Up @@ -358,6 +390,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (score != null) {
builder.field(SCORE_FIELD.getPreferredName(), score);

Check warning on line 394 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L394

Added line #L394 was not covered by tests
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -397,8 +432,8 @@ protected Query doToQuery(QueryShardContext context) {
spaceType = knnMethodContext.getSpaceType();
}

// Currently, k-NN supports distance type radius search.
// We need transform distance radius to right type of engine required radius.
// Currently, k-NN supports both distance and score thresholds for radial search.
// We need to transform the distance/score into the radius type required by the engine.
Float radius = null;
if (this.distance != null) {
if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
Expand All @@ -407,6 +442,13 @@ protected Query doToQuery(QueryShardContext context) {
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
}

if (this.score != null) {
if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType);
}
radius = knnEngine.scoreToRadialThreshold(this.score, spaceType);
}

if (fieldDimension != vector.length) {
throw new IllegalArgumentException(
String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vector.length, fieldDimension)
Expand Down Expand Up @@ -464,7 +506,7 @@ protected Query doToQuery(QueryShardContext context) {
.build();
return RNNQueryFactory.create(createQueryRequest);
}
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set");

Check warning on line 509 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L509

Added line #L509 was not covered by tests
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
Expand Down Expand Up @@ -499,4 +541,24 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

/**
 * Verifies that exactly one query type is selected among k, distance and score.
 *
 * k counts as set when it is non-null and non-zero; distance and score count as set
 * when they are non-null.
 *
 * @param k        number of nearest neighbours, or null/0 when not set
 * @param distance distance threshold, or null when not set
 * @param score    score threshold, or null when not set
 * @throws IllegalArgumentException if none, or more than one, of the three is set
 */
private static void validSingleQueryType(Integer k, Float distance, Float score) {
int countSetFields = 0;

// A k of 0 is treated the same as "not provided".
if (k != null && k != 0) {
countSetFields++;
}
if (distance != null) {
countSetFields++;
}
if (score != null) {
countSetFields++;
}

if (countSetFields != 1) {
throw new IllegalArgumentException(

Check warning on line 559 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L559

Added line #L559 was not covered by tests
"[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score"
);
}
}
}
27 changes: 25 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
* Implements NativeLibrary for the faiss native library
*/
class Faiss extends NativeLibrary {
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
// about the compatibility version the file is created with. In the future, we should refactor this so that it
Expand All @@ -68,6 +69,14 @@ class Faiss extends NativeLibrary {
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
);

// Map that transforms radial search score threshold to faiss required distance
// L2: 1 / score - 1, the algebraic inverse of score = 1 / (1 + distance).
// INNER_PRODUCT: 1 - score when score > 1, otherwise 1 / score - 1.
// NOTE(review): the 1 / score branches yield Infinity for score == 0, so this presumably relies on
// callers rejecting non-positive scores before the lookup — confirm.
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder()
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1)
.put(SpaceType.L2, score -> 1 / score - 1)
.build();

// Define encoders supported by faiss
private final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext(
KNNConstants.ENCODER_FLAT,
Expand Down Expand Up @@ -301,7 +310,13 @@ class Faiss extends NativeLibrary {
).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build()
);

final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION);
final static Faiss INSTANCE = new Faiss(
METHODS,
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
);

/**
* Constructor for Faiss
Expand All @@ -315,9 +330,11 @@ private Faiss(
Map<String, KNNMethod> methods,
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
}

@Override
Expand All @@ -326,6 +343,12 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

/**
 * Converts a user supplied score threshold into the radius value handed to faiss, using the
 * transformation registered for the given space type.
 *
 * @param score     score threshold input from end user
 * @param spaceType space type whose transformation should be applied
 * @return the transformed threshold
 * @throws IllegalArgumentException if no score transformation is registered for the space type
 */
@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
    // Unlike the distance threshold, a score threshold is not used as is: it is mapped through
    // the per-space-type transformation first.
    final Function<Float, Float> transform = this.scoreTransform.get(spaceType);
    if (transform == null) {
        // Fail with a descriptive message instead of a bare NullPointerException on apply().
        throw new IllegalArgumentException("Score threshold is not supported by faiss for space type: " + spaceType);
    }
    return transform.apply(score);
}

/**
* MethodAsMap builder is used to create the map that will be passed to the jni to create the faiss index.
* Faiss's index factory takes an "index description" that it uses to build the index. In this description,
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return knnLibrary.distanceToRadialThreshold(distance, spaceType);
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
    // Pure delegation: the wrapped library decides how a score threshold maps to its radius.
    final Float radialThreshold = knnLibrary.scoreToRadialThreshold(score, spaceType);
    return radialThreshold;
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
return knnLibrary.validateMethod(knnMethodContext);
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ public interface KNNLibrary {
*/
Float distanceToRadialThreshold(Float distance, SpaceType spaceType);

/**
* Translate the score threshold input from end user to the engine's threshold.
*
* @param score score threshold input from end user
* @param spaceType spaceType used to compute the threshold
*
* @return transformed score for the library
*/
Float scoreToRadialThreshold(Float score, SpaceType spaceType);

/**
* Validate the knnMethodContext for the given library. A ValidationException should be thrown if the method is
* deemed invalid.
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class Lucene extends JVMLibrary {
Function<Float, Float>>builder()
.put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2)
.put(SpaceType.L2, distance -> 1 / (1 + distance))
.put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 1 / (1 - distance) : distance + 1)
.build();

final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS);
Expand Down Expand Up @@ -93,6 +94,12 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return this.distanceTransform.get(spaceType).apply(distance);
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// The score threshold is passed through as is; no transformation is applied for the
// Lucene engine, and spaceType is not consulted.
return score;
}

@Override
public List<String> mmapFileExtensions() {
return List.of("vec", "vex");
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Nmslib.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ private Nmslib(
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

// Returns the score threshold unchanged; spaceType is not consulted.
// NOTE(review): unlike the Faiss and Lucene implementations in this change, this method has no
// @Override annotation — presumably it implements KNNLibrary#scoreToRadialThreshold; confirm and annotate.
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
return score;

Check warning on line 78 in src/main/java/org/opensearch/knn/index/util/Nmslib.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/util/Nmslib.java#L78

Added line #L78 was not covered by tests
}
}
Loading

0 comments on commit 7de0298

Please sign in to comment.