Skip to content

Commit

Permalink
Rename radial search parameters score and distance to min_score and m…
Browse files Browse the repository at this point in the history
…ax_distance

Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 10, 2024
1 parent 149350d commit 19bbca6
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 99 deletions.
104 changes: 52 additions & 52 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField DISTANCE_FIELD = new ParseField("distance");
public static final ParseField SCORE_FIELD = new ParseField("score");
public static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");
public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");
public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -64,8 +64,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final String fieldName;
private final float[] vector;
private int k = 0;
private Float distance = null;
private Float score = null;
private Float max_distance = null;
private Float min_score = null;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

Expand Down Expand Up @@ -98,7 +98,7 @@ public KNNQueryBuilder k(Integer k) {
if (k == null) {
throw new IllegalArgumentException("[" + NAME + "] requires k to be set");
}
validateSingleQueryType(k, distance, score);
validateSingleQueryType(k, max_distance, min_score);
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX);
}
Expand All @@ -107,33 +107,33 @@ public KNNQueryBuilder k(Integer k) {
}

/**
* Builder method for distance
* Builder method for max_distance
*
* @param distance the distance threshold for the nearest neighbours
* @param max_distance the max_distance threshold for the nearest neighbours
*/
public KNNQueryBuilder distance(Float distance) {
if (distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be set");
public KNNQueryBuilder maxDistance(Float max_distance) {
if (max_distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires max_distance to be set");
}
validateSingleQueryType(k, distance, score);
this.distance = distance;
validateSingleQueryType(k, max_distance, min_score);
this.max_distance = max_distance;
return this;
}

/**
* Builder method for score
* Builder method for min_score
*
* @param score the score threshold for the nearest neighbours
* @param min_score the min_score threshold for the nearest neighbours
*/
public KNNQueryBuilder score(Float score) {
if (score == null) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be set");
public KNNQueryBuilder minScore(Float min_score) {
if (min_score == null) {
throw new IllegalArgumentException("[" + NAME + "] requires min_score to be set");
}
validateSingleQueryType(k, distance, score);
if (score <= 0) {
throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0");
validateSingleQueryType(k, max_distance, min_score);
if (min_score <= 0) {
throw new IllegalArgumentException("[" + NAME + "] requires min_score greater than 0");
}
this.score = score;
this.min_score = min_score;
return this;
}

Expand Down Expand Up @@ -180,8 +180,8 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
this.distance = null;
this.score = null;
this.max_distance = null;
this.min_score = null;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -217,10 +217,10 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
ignoreUnmapped = in.readOptionalBoolean();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
distance = in.readOptionalFloat();
max_distance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
score = in.readOptionalFloat();
min_score = in.readOptionalFloat();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
Expand All @@ -232,8 +232,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
List<Object> vector = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float distance = null;
Float score = null;
Float max_distance = null;
Float min_score = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
Expand Down Expand Up @@ -262,10 +262,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
max_distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
min_score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -295,7 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

validateSingleQueryType(k, distance, score);
validateSingleQueryType(k, max_distance, min_score);

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand All @@ -304,10 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep

if (k != null) {
knnQueryBuilder.k(k);
} else if (distance != null) {
knnQueryBuilder.distance(distance);
} else if (score != null) {
knnQueryBuilder.score(score);
} else if (max_distance != null) {
knnQueryBuilder.maxDistance(max_distance);
} else if (min_score != null) {
knnQueryBuilder.minScore(min_score);
}

return knnQueryBuilder;
Expand All @@ -323,10 +323,10 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalBoolean(ignoreUnmapped);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(distance);
out.writeOptionalFloat(max_distance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(score);
out.writeOptionalFloat(min_score);
}
}

Expand All @@ -348,12 +348,12 @@ public int getK() {
return this.k;
}

public float getDistance() {
return this.distance;
public float getMax_distance() {
return this.max_distance;
}

public float getScore() {
return this.score;
public float getMin_score() {
return this.min_score;
}

public QueryBuilder getFilter() {
Expand Down Expand Up @@ -384,14 +384,14 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (distance != null) {
builder.field(DISTANCE_FIELD.getPreferredName(), distance);
if (max_distance != null) {
builder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (score != null) {
builder.field(SCORE_FIELD.getPreferredName(), score);
if (min_score != null) {
builder.field(MIN_SCORE_FIELD.getPreferredName(), min_score);
}
printBoostAndQueryName(builder);
builder.endObject();
Expand Down Expand Up @@ -435,18 +435,18 @@ protected Query doToQuery(QueryShardContext context) {
// Currently, k-NN supports distance and score types radial search
// We need transform distance/score to right type of engine required radius.
Float radius = null;
if (this.distance != null) {
if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
if (this.max_distance != null) {
if (this.max_distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType);
}
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
radius = knnEngine.distanceToRadialThreshold(this.max_distance, spaceType);
}

if (this.score != null) {
if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
if (this.min_score != null) {
if (this.min_score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType);
}
radius = knnEngine.scoreToRadialThreshold(this.score, spaceType);
radius = knnEngine.scoreToRadialThreshold(this.min_score, spaceType);
}

if (fieldDimension != vector.length) {
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,7 @@ private void validateRadiusSearchResults(
if (distanceThreshold != null) {
queryBuilder.field("distance", distanceThreshold);
} else if (scoreThreshold != null) {
queryBuilder.field("score", scoreThreshold);
queryBuilder.field("min_score", scoreThreshold);
} else {
throw new IllegalArgumentException("Invalid threshold");
}
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ private void validateRadiusSearchResults(
builder.startObject(FIELD_NAME);
builder.field("vector", searchVectors[i]);
if (distanceThreshold != null) {
builder.field("distance", distanceThreshold);
builder.field("max_distance", distanceThreshold);
} else if (scoreThreshold != null) {
builder.field("score", scoreThreshold);
builder.field("min_score", scoreThreshold);
} else {
throw new IllegalArgumentException("Either distance or score must be provided");
}
Expand Down
Loading

0 comments on commit 19bbca6

Please sign in to comment.