Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for radial search on Neural query #1235

Merged
merged 4 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ This section is for maintaining a changelog for all breaking changes for the cli

### Added
- Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166))
- Added `minScore` and `maxDistance` to `NeuralQuery` ([#1235](https://github.com/opensearch-project/opensearch-java/pull/1235))

### Dependencies

Expand Down Expand Up @@ -567,4 +568,4 @@ This section is for maintaining a changelog for all breaking changes for the cli
[2.5.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.4.0...v2.5.0
[2.4.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.3.0...v2.4.0
[2.3.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.2.0...v2.3.0
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public final float[] vector() {
* Optional - The number of neighbors the search of each graph will return.
* @return The number of neighbors to return.
*/
@Nullable
public final Integer k() {
return this.k;
}
Expand All @@ -84,6 +85,7 @@ public final Integer k() {
* Optional - The minimum score allowed for the returned search results.
* @return The minimum score allowed for the returned search results.
*/
@Nullable
private final Float minScore() {
return this.minScore;
}
Expand All @@ -92,6 +94,7 @@ private final Float minScore() {
* Optional - The maximum distance allowed between the vector and each of the returned search results.
* @return The maximum distance allowed between the vector and each ofthe returned search results.
*/
@Nullable
private final Float maxDistance() {
return this.maxDistance;
}
Expand All @@ -111,8 +114,6 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);

// TODO: Implement the rest of the serialization.

generator.writeKey("vector");
generator.writeStartArray();
for (float value : this.vector) {
Expand Down Expand Up @@ -183,7 +184,7 @@ public Builder vector(@Nullable float[] vector) {
}

/**
* Required - The number of neighbors the search of each graph will return.
* Optional - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ public class NeuralQuery extends QueryBase implements QueryVariant {
private final String field;
private final String queryText;
private final String queryImage;
private final int k;
@Nullable
private final Integer k;
@Nullable
private final Float minScore;
@Nullable
private final Float maxDistance;
@Nullable
private final String modelId;
@Nullable
Expand All @@ -41,7 +46,9 @@ private NeuralQuery(NeuralQuery.Builder builder) {
}
this.queryText = builder.queryText;
this.queryImage = builder.queryImage;
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.k = builder.k;
this.minScore = builder.minScore;
this.maxDistance = builder.maxDistance;
this.modelId = builder.modelId;
this.filter = builder.filter;
}
Expand Down Expand Up @@ -90,17 +97,34 @@ public final String queryImage() {
}

/**
* Required - The number of neighbors to return.
* Optional - The number of neighbors to return.
*
* @return The number of neighbors to return.
*/
public final int k() {
@Nullable
public final Integer k() {
return this.k;
}

/**
* Builder for {@link NeuralQuery}.
* Optional - The minimum score threshold for the search results
*
* @return The minimum score threshold for the search results
*/
@Nullable
public final Float minScore() {
return this.minScore;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @return The maximum distance threshold for the search results
*/
@Nullable
public final Float maxDistance() {
return this.maxDistance;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
Expand Down Expand Up @@ -141,7 +165,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.write("model_id", this.modelId);
}

generator.write("k", this.k);
if (this.k != null) {
generator.write("k", this.k);
}

if (this.minScore != null) {
generator.write("min_score", this.minScore);
}

if (this.maxDistance != null) {
generator.write("max_distance", this.maxDistance);
}

if (this.filter != null) {
generator.writeKey("filter");
Expand All @@ -152,7 +186,14 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter);
return toBuilder(new Builder()).field(field)
.queryText(queryText)
.queryImage(queryImage)
.k(k)
.minScore(minScore)
.maxDistance(maxDistance)
.modelId(modelId)
.filter(filter);
}

/**
Expand All @@ -162,8 +203,13 @@ public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builde
private String field;
private String queryText;
private String queryImage;
@Nullable
private Integer k;
@Nullable
private Float minScore;
@Nullable
private Float maxDistance;
@Nullable
private String modelId;
@Nullable
private Query filter;
Expand Down Expand Up @@ -216,7 +262,7 @@ public NeuralQuery.Builder modelId(@Nullable String modelId) {
}

/**
* Required - The number of neighbors to return.
* Optional - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
Expand All @@ -226,6 +272,28 @@ public NeuralQuery.Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - The minimum score threshold for the search results
*
* @param minScore The minimum score threshold for the search results
* @return This builder.
*/
public NeuralQuery.Builder minScore(@Nullable Float minScore) {
this.minScore = minScore;
return this;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @param maxDistance The maximum distance threshold for the search results
* @return This builder.
*/
public NeuralQuery.Builder maxDistance(@Nullable Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
Expand Down Expand Up @@ -267,6 +335,8 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
op.add(NeuralQuery.Builder::queryImage, JsonpDeserializer.stringDeserializer(), "query_image");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
op.add(NeuralQuery.Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public void testNeuralQuery() {
assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
assertEquals((Integer) 100, searchRequest.query().neural().k());
}

@Test
Expand Down Expand Up @@ -251,7 +251,7 @@ public void testNeuralQueryFromJson() {
searchRequest.query().neural().queryImage()
);
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
assertEquals((Integer) 100, searchRequest.query().neural().k());
}

@Test
Expand Down Expand Up @@ -279,10 +279,10 @@ public void testHybridQuery() {
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals((Integer) 100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals((Integer) 2, searchRequest.query().hybrid().queries().get(2).knn().k());
}

@Test
Expand All @@ -301,9 +301,9 @@ public void testHybridQueryFromJson() {
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals((Integer) 100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals((Integer) 2, searchRequest.query().hybrid().queries().get(2).knn().k());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.List;
import java.util.Map;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.opensearch._types.Refresh;
import org.opensearch.client.opensearch._types.mapping.Property;
Expand Down Expand Up @@ -89,6 +90,14 @@ public void testTemplateSearchAggregations() throws Exception {

@Test
public void testMultiSearchTemplate() throws Exception {
Integer expectedSuccessStatus = null;
Integer expectedFailureStatus = null;

if (getServerVersion().onOrAfter(Version.V_2_18_0)) {
expectedSuccessStatus = 200;
expectedFailureStatus = 404;
}

var index = "test-msearch-template";
createDocuments(index);

Expand Down Expand Up @@ -120,11 +129,11 @@ public void testMultiSearchTemplate() throws Exception {
assertEquals(2, searchResponse.responses().size());
var response = searchResponse.responses().get(0);
assertTrue(response.isResult());
assertNull(response.result().status());
assertEquals(expectedSuccessStatus, response.result().status());
assertEquals(4, response.result().hits().hits().size());
var failureResponse = searchResponse.responses().get(1);
assertTrue(failureResponse.isFailure());
assertNull(failureResponse.failure().status());
assertEquals(expectedFailureStatus, failureResponse.failure().status());
}

private SearchTemplateResponse<SimpleDoc> sendTemplateRequest(String index, String title, boolean suggs, boolean aggs)
Expand Down
Loading