Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename radial search parameters score and distance to min_score and max_distance #1609

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 54 additions & 54 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField DISTANCE_FIELD = new ParseField("distance");
public static final ParseField SCORE_FIELD = new ParseField("score");
public static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");
public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");
public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -64,17 +64,17 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final String fieldName;
private final float[] vector;
private int k = 0;
private Float distance = null;
private Float score = null;
private Float max_distance = null;
private Float min_score = null;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

/**
* Constructs a new query with the given field name and vector
* Constructs a new query with the given field name and vector
*
* @param fieldName Name of the field
* @param vector Array of floating points
*/
*/
public KNNQueryBuilder(String fieldName, float[] vector) {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException("[" + NAME + "] requires fieldName");
Expand All @@ -98,7 +98,7 @@ public KNNQueryBuilder k(Integer k) {
if (k == null) {
throw new IllegalArgumentException("[" + NAME + "] requires k to be set");
}
validateSingleQueryType(k, distance, score);
validateSingleQueryType(k, max_distance, min_score);
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX);
}
Expand All @@ -107,33 +107,33 @@ public KNNQueryBuilder k(Integer k) {
}

/**
* Builder method for distance
* Builder method for max_distance
*
* @param distance the distance threshold for the nearest neighbours
* @param max_distance the max_distance threshold for the nearest neighbours
*/
public KNNQueryBuilder distance(Float distance) {
if (distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be set");
public KNNQueryBuilder maxDistance(Float max_distance) {
if (max_distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires max_distance to be set");
}
validateSingleQueryType(k, distance, score);
this.distance = distance;
validateSingleQueryType(k, max_distance, min_score);
this.max_distance = max_distance;
return this;
}

/**
* Builder method for score
* Builder method for min_score
*
* @param score the score threshold for the nearest neighbours
* @param min_score the min_score threshold for the nearest neighbours
*/
public KNNQueryBuilder score(Float score) {
if (score == null) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be set");
public KNNQueryBuilder minScore(Float min_score) {
if (min_score == null) {
throw new IllegalArgumentException("[" + NAME + "] requires min_score to be set");
}
validateSingleQueryType(k, distance, score);
if (score <= 0) {
throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0");
validateSingleQueryType(k, max_distance, min_score);
if (min_score <= 0) {
throw new IllegalArgumentException("[" + NAME + "] requires min_score greater than 0");
}
this.score = score;
this.min_score = min_score;
return this;
}

Expand Down Expand Up @@ -180,8 +180,8 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
this.distance = null;
this.score = null;
this.max_distance = null;
this.min_score = null;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -217,10 +217,10 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
ignoreUnmapped = in.readOptionalBoolean();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
distance = in.readOptionalFloat();
max_distance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
score = in.readOptionalFloat();
min_score = in.readOptionalFloat();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
Expand All @@ -232,8 +232,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
List<Object> vector = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float distance = null;
Float score = null;
Float max_distance = null;
Float min_score = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
Expand Down Expand Up @@ -262,10 +262,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
max_distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
min_score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -295,7 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

validateSingleQueryType(k, distance, score);
validateSingleQueryType(k, max_distance, min_score);

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand All @@ -304,10 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep

if (k != null) {
knnQueryBuilder.k(k);
} else if (distance != null) {
knnQueryBuilder.distance(distance);
} else if (score != null) {
knnQueryBuilder.score(score);
} else if (max_distance != null) {
knnQueryBuilder.maxDistance(max_distance);
} else if (min_score != null) {
knnQueryBuilder.minScore(min_score);
}

return knnQueryBuilder;
Expand All @@ -323,10 +323,10 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalBoolean(ignoreUnmapped);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(distance);
out.writeOptionalFloat(max_distance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(score);
out.writeOptionalFloat(min_score);
}
}

Expand All @@ -348,12 +348,12 @@ public int getK() {
return this.k;
}

public float getDistance() {
return this.distance;
public float getMax_distance() {
junqiu-lei marked this conversation as resolved.
Show resolved Hide resolved
return this.max_distance;
}

public float getScore() {
return this.score;
public float getMin_score() {
junqiu-lei marked this conversation as resolved.
Show resolved Hide resolved
return this.min_score;
}

public QueryBuilder getFilter() {
Expand Down Expand Up @@ -384,14 +384,14 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (distance != null) {
builder.field(DISTANCE_FIELD.getPreferredName(), distance);
if (max_distance != null) {
builder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (score != null) {
builder.field(SCORE_FIELD.getPreferredName(), score);
if (min_score != null) {
builder.field(MIN_SCORE_FIELD.getPreferredName(), min_score);
}
printBoostAndQueryName(builder);
builder.endObject();
Expand Down Expand Up @@ -435,18 +435,18 @@ protected Query doToQuery(QueryShardContext context) {
// Currently, k-NN supports distance and score types radial search
// We need transform distance/score to right type of engine required radius.
Float radius = null;
if (this.distance != null) {
if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
if (this.max_distance != null) {
if (this.max_distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType);
}
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
radius = knnEngine.distanceToRadialThreshold(this.max_distance, spaceType);
}

if (this.score != null) {
if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
if (this.min_score != null) {
if (this.min_score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType);
}
radius = knnEngine.scoreToRadialThreshold(this.score, spaceType);
radius = knnEngine.scoreToRadialThreshold(this.min_score, spaceType);
}

if (fieldDimension != vector.length) {
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -1702,9 +1702,9 @@ private void validateRadiusSearchResults(
queryBuilder.startObject(fieldName);
queryBuilder.field("vector", queryVector);
if (distanceThreshold != null) {
queryBuilder.field("distance", distanceThreshold);
queryBuilder.field("max_distance", distanceThreshold);
} else if (scoreThreshold != null) {
queryBuilder.field("score", scoreThreshold);
queryBuilder.field("min_score", scoreThreshold);
} else {
throw new IllegalArgumentException("Invalid threshold");
}
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ private void validateRadiusSearchResults(
builder.startObject(FIELD_NAME);
builder.field("vector", searchVectors[i]);
if (distanceThreshold != null) {
builder.field("distance", distanceThreshold);
builder.field("max_distance", distanceThreshold);
} else if (scoreThreshold != null) {
builder.field("score", scoreThreshold);
builder.field("min_score", scoreThreshold);
} else {
throw new IllegalArgumentException("Either distance or score must be provided");
}
Expand Down
Loading
Loading