diff --git a/CHANGELOG.md b/CHANGELOG.md index e3bb2f33a7..1c5084a235 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased 2.x] ### Added - Add search role type for nodes in cluster stats ([#848](https://github.com/opensearch-project/opensearch-java/pull/848)) +- Add support for Hybrid query type ([#850](https://github.com/opensearch-project/opensearch-java/pull/850)) ### Dependencies diff --git a/guides/search.md b/guides/search.md index d343b30420..7fd3e06cc9 100644 --- a/guides/search.md +++ b/guides/search.md @@ -81,6 +81,25 @@ for (int i = 0; i < searchResponse.hits().hits().size(); i++) { } ``` +### Search documents using a hybrid query +```java +Query searchQuery = Query.of( + h -> h.hybrid( + q -> q.queries(Arrays.asList( + new MatchQuery.Builder().field("text").query(FieldValue.of("Text for document 2")).build().toQuery(), + new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(), + new NeuralQuery.Builder().field("passage_embedding").queryText("Hi world").modelId("bQ1J8ooBpBj3wT4HVUsb").k(100).build().toQuery() + ) + ) + ) + ); +SearchRequest searchRequest = new SearchRequest.Builder().query(searchQuery).build(); +SearchResponse searchResponse = client.search(searchRequest, IndexData.class); +for (var hit : searchResponse.hits().hits()) { + LOGGER.info("Found {} with score {}", hit.source(), hit.score()); +} +``` + ### Search documents using suggesters [AppData](../samples/src/main/java/org/opensearch/client/samples/util/AppData.java) refers to the sample data class used in the below samples. 
diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/HybridQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/HybridQuery.java new file mode 100644 index 0000000000..d2448ff644 --- /dev/null +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/HybridQuery.java @@ -0,0 +1,107 @@ +package org.opensearch.client.opensearch._types.query_dsl; + +import jakarta.json.stream.JsonGenerator; +import java.util.List; +import java.util.function.Function; +import org.opensearch.client.json.JsonpDeserializer; +import org.opensearch.client.json.JsonpMapper; +import org.opensearch.client.json.ObjectBuilderDeserializer; +import org.opensearch.client.json.ObjectDeserializer; +import org.opensearch.client.util.ApiTypeHelper; +import org.opensearch.client.util.ObjectBuilder; + +public class HybridQuery extends QueryBase implements QueryVariant { + private final List queries; + + private HybridQuery(HybridQuery.Builder builder) { + super(builder); + this.queries = ApiTypeHelper.unmodifiable(builder.queries); + } + + public static HybridQuery of(Function> fn) { + return fn.apply(new HybridQuery.Builder()).build(); + } + + /** + * Required - list of search queries. + * + * @return list of queries provided under hybrid clause. 
+ */ + public final List queries() { + return this.queries; + } + + @Override + protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { + super.serializeInternal(generator, mapper); + generator.writeKey("queries"); + generator.writeStartArray(); + for (Query item0 : this.queries) { + item0.serialize(generator, mapper); + } + generator.writeEnd(); + } + + @Override + public Query.Kind _queryKind() { + return Query.Kind.Hybrid; + } + + public HybridQuery.Builder toBuilder() { + return new HybridQuery.Builder().queries(queries); /* NOTE(review): only queries is copied into the new builder; any state held by QueryBase is not carried over here - confirm this is intended */ + } + + public static class Builder extends QueryBase.AbstractBuilder implements ObjectBuilder { + private List queries; + + /** + * API name: {@code queries} + *

+ * Adds all elements of list to queries. + */ + public final HybridQuery.Builder queries(List list) { + this.queries = _listAddAll(this.queries, list); + return this; + } + + /** + * API name: {@code queries} + *

+ * Adds one or more values to queries. + */ + public final HybridQuery.Builder queries(Query value, Query... values) { + this.queries = _listAdd(this.queries, value, values); + return this; + } + + /** + * API name: {@code queries} + *

+ * Adds a value to queries using a builder lambda. + */ + public final HybridQuery.Builder queries(Function> fn) { + return queries(fn.apply(new Query.Builder()).build()); + } + + @Override + protected Builder self() { + return this; + } + + @Override + public HybridQuery build() { + _checkSingleUse(); + return new HybridQuery(this); + } + } + + public static final JsonpDeserializer _DESERIALIZER = ObjectBuilderDeserializer.lazy( + HybridQuery.Builder::new, + HybridQuery::setupHybridQueryDeserializer + ); + + protected static void setupHybridQueryDeserializer(ObjectDeserializer op) { + setupQueryBaseDeserializer(op); + op.add(HybridQuery.Builder::queries, JsonpDeserializer.arrayDeserializer(Query._DESERIALIZER), "queries"); + } +} diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java index 6167d02d3d..1510a19473 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java @@ -120,6 +120,8 @@ public enum Kind implements JsonEnum { Neural("neural"), + Hybrid("hybrid"), + ParentId("parent_id"), Percolate("percolate"), @@ -725,6 +727,23 @@ public NeuralQuery neural() { return TaggedUnionUtils.get(this, Kind.Neural); } + /** + * Is this variant instance of kind {@code hybrid}? + */ + public boolean isHybrid() { + return _kind == Kind.Hybrid; + } + + /** + * Get the {@code hybrid} variant value. + * + * @throws IllegalStateException + * if the current variant is not of the {@code hybrid} kind. + */ + public HybridQuery hybrid() { + return TaggedUnionUtils.get(this, Kind.Hybrid); + } + /** * Is this variant instance of kind {@code parent_id}? 
*/ @@ -1510,6 +1529,16 @@ public ObjectBuilder neural(Function hybrid(HybridQuery v) { + this._kind = Kind.Hybrid; + this._value = v; + return this; + } + + public ObjectBuilder hybrid(Function> fn) { + return this.hybrid(fn.apply(new HybridQuery.Builder()).build()); + } + public ObjectBuilder parentId(ParentIdQuery v) { this._kind = Kind.ParentId; this._value = v; @@ -1818,6 +1847,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer op) { op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match"); op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested"); op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural"); + op.add(Builder::hybrid, HybridQuery._DESERIALIZER, "hybrid"); op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id"); op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate"); op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned"); diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java index f165ac7060..8ddf85b1e8 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java @@ -261,6 +261,13 @@ public static NeuralQuery.Builder neural() { return new NeuralQuery.Builder(); } + /** + * Creates a builder for the {@link HybridQuery hybrid} {@code Query} variant. + */ + public static HybridQuery.Builder hybrid() { + return new HybridQuery.Builder(); + } + /** * Creates a builder for the {@link ParentIdQuery parent_id} {@code Query} * variant. 
diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/HybridQueryTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/HybridQueryTest.java new file mode 100644 index 0000000000..a897faa9d9 --- /dev/null +++ b/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/HybridQueryTest.java @@ -0,0 +1,27 @@ +package org.opensearch.client.opensearch._types.query_dsl; + +import java.util.Arrays; +import org.junit.Test; +import org.opensearch.client.opensearch._types.FieldValue; +import org.opensearch.client.opensearch.model.ModelTestCase; + +public class HybridQueryTest extends ModelTestCase { + @Test + public void toBuilder() { + HybridQuery origin = new HybridQuery.Builder().queries( + Arrays.asList( + new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(), + new NeuralQuery.Builder().field("passage_embedding") + .queryText("Hi world") + .modelId("bQ1J8ooBpBj3wT4HVUsb") + .k(100) + .build() + .toQuery(), + new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery() + ) + ).build(); + HybridQuery copied = origin.toBuilder().build(); + + assertEquals(toJson(origin), toJson(copied)); /* expected (origin) first, actual (round-tripped copy) second */ + } +} diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/integTest/AbstractSearchRequestIT.java b/java-client/src/test/java/org/opensearch/client/opensearch/integTest/AbstractSearchRequestIT.java index 19f3af617f..9c97e3a17b 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/integTest/AbstractSearchRequestIT.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/integTest/AbstractSearchRequestIT.java @@ -15,10 +15,12 @@ import org.opensearch.client.opensearch._types.FieldValue; import org.opensearch.client.opensearch._types.SortOrder; import org.opensearch.client.opensearch._types.mapping.Property; +import 
org.opensearch.client.opensearch._types.query_dsl.MatchQuery; import org.opensearch.client.opensearch._types.query_dsl.Query; import org.opensearch.client.opensearch._types.query_dsl.TermQuery; import org.opensearch.client.opensearch.core.SearchRequest; import org.opensearch.client.opensearch.core.SearchResponse; +import org.opensearch.client.opensearch.indices.DeleteIndexRequest; import org.opensearch.client.opensearch.indices.SegmentSortOrder; public abstract class AbstractSearchRequestIT extends OpenSearchJavaClientTestCase { @@ -26,21 +28,7 @@ public abstract class AbstractSearchRequestIT extends OpenSearchJavaClientTestCa @Test public void shouldReturnSearchResults() throws Exception { final String index = "search_request"; - assertTrue( - javaClient().indices() - .create( - b -> b.index(index) - .mappings( - m -> m.properties("name", Property.of(p -> p.keyword(v -> v.docValues(true)))) - .properties("size", Property.of(p -> p.keyword(v -> v.docValues(true)))) - ) - .settings(settings -> settings.sort(s -> s.field("name").order(SegmentSortOrder.Asc))) - ) - .acknowledged() - ); - - createTestDocuments(index); - javaClient().indices().refresh(); + createIndex(index); final Query query = Query.of( q -> q.bool( @@ -72,23 +60,47 @@ public void shouldReturnSearchResults() throws Exception { } @Test - public void shouldReturnSearchResultsWithoutStoredFields() throws Exception { - final String index = "search_request"; - assertTrue( - javaClient().indices() - .create( - b -> b.index(index) - .mappings( - m -> m.properties("name", Property.of(p -> p.keyword(v -> v.docValues(true)))) - .properties("size", Property.of(p -> p.keyword(v -> v.docValues(true)))) - ) - .settings(settings -> settings.sort(s -> s.field("name").order(SegmentSortOrder.Asc))) + public void hybridSearchShouldReturnSearchResults() throws Exception { + final String index = "hybrid_search_request"; + try { + createIndex(index); + final Query query = Query.of( + h -> h.hybrid( + q -> 
q.queries(Arrays.asList(new MatchQuery.Builder().field("size").query(FieldValue.of("huge")).build().toQuery())) ) - .acknowledged() - ); + ); + + final SearchRequest request = SearchRequest.of( + r -> r.index(index) + .sort(s -> s.field(f -> f.field("name").order(SortOrder.Asc))) + .fields(f -> f.field("name")) + .query(query) + .source(s -> s.fetch(true)) + ); + + final SearchResponse response = javaClient().search(request, ShopItem.class); + assertEquals(5, response.hits().hits().size()); /* JUnit order: expected, then actual */ + + assertTrue( + Arrays.stream(response.hits().hits().get(2).fields().get("name").to(String[].class)) + .collect(Collectors.toList()) + .contains("hummer") + ); + assertTrue( + Arrays.stream(response.hits().hits().get(3).fields().get("name").to(String[].class)) + .collect(Collectors.toList()) + .contains("jammer") + ); + } finally { + DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest.Builder().index(index).build(); + javaClient().indices().delete(deleteIndexRequest); + } + } - createTestDocuments(index); - javaClient().indices().refresh(); + @Test + public void shouldReturnSearchResultsWithoutStoredFields() throws Exception { + final String index = "search_request"; + createIndex(index); final Query query = Query.of( q -> q.bool( @@ -117,6 +129,23 @@ private void createTestDocuments(String index) throws IOException { javaClient().create(_1 -> _1.index(index).id("8").document(createItem("nuts", "small", "no", 2))); } + private void createIndex(String index) throws IOException { + assertTrue( + javaClient().indices() + .create( + b -> b.index(index) + .mappings( + m -> m.properties("name", Property.of(p -> p.keyword(v -> v.docValues(true)))) + .properties("size", Property.of(p -> p.keyword(v -> v.docValues(true)))) + ) + .settings(settings -> settings.sort(s -> s.field("name").order(SegmentSortOrder.Asc))) + ) + .acknowledged() + ); + createTestDocuments(index); + javaClient().indices().refresh(); + } + private ShopItem createItem(String name, String size, String 
company, int quantity) { return new ShopItem(name, size, company, quantity); } diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java index e504050c19..addb6af775 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java @@ -32,12 +32,17 @@ package org.opensearch.client.opensearch.model; +import java.util.Arrays; import org.junit.Test; import org.opensearch.client.json.JsonData; +import org.opensearch.client.opensearch._types.FieldValue; import org.opensearch.client.opensearch._types.mapping.Property; import org.opensearch.client.opensearch._types.mapping.TypeMapping; +import org.opensearch.client.opensearch._types.query_dsl.KnnQuery; +import org.opensearch.client.opensearch._types.query_dsl.NeuralQuery; import org.opensearch.client.opensearch._types.query_dsl.Query; import org.opensearch.client.opensearch._types.query_dsl.QueryBuilders; +import org.opensearch.client.opensearch._types.query_dsl.TermQuery; import org.opensearch.client.opensearch.core.SearchRequest; import org.opensearch.client.opensearch.indices.GetMappingResponse; @@ -243,4 +248,57 @@ public void testNeuralQueryFromJson() { assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId()); assertEquals(100, searchRequest.query().neural().k()); } + + @Test + public void testHybridQuery() { + + Query query = Query.of( + h -> h.hybrid( + q -> q.queries( + Arrays.asList( + new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(), + new NeuralQuery.Builder().field("passage_embedding") + .queryText("Hi world") + .modelId("bQ1J8ooBpBj3wT4HVUsb") + .k(100) + .build() + .toQuery(), + new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery() + ) + ) + ) + ); 
+ SearchRequest searchRequest = SearchRequest.of(s -> s.query(query)); + assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field()); + assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue()); + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field()); + assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText()); + assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId()); + assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k()); + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); + } + + @Test + public void testHybridQueryFromJson() { + + String json = "{\"query\"" + + ":{\"hybrid\":{\"queries\":[{\"term\":{\"passage_text\":\"Foo bar\"}}," + + "{\"neural\":{\"passage_embedding\":{\"query_text\":\"Hi world\",\"model_id\":\"bQ1J8ooBpBj3wT4HVUsb\",\"k\":100}}}," + + "{\"knn\":{\"passage_embedding\":{\"vector\":[0.01,0.02],\"k\":2}}}]}},\"size\":10" + + "}"; + + SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper); + + assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field()); + assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue()); + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field()); + assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText()); + assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId()); + assertEquals(100, 
searchRequest.query().hybrid().queries().get(1).neural().k()); + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); + } } diff --git a/samples/src/main/java/org/opensearch/client/samples/Search.java b/samples/src/main/java/org/opensearch/client/samples/Search.java index ce357d2942..417ecba409 100644 --- a/samples/src/main/java/org/opensearch/client/samples/Search.java +++ b/samples/src/main/java/org/opensearch/client/samples/Search.java @@ -26,6 +26,8 @@ import org.opensearch.client.opensearch._types.mapping.Property; import org.opensearch.client.opensearch._types.mapping.TextProperty; import org.opensearch.client.opensearch._types.mapping.TypeMapping; +import org.opensearch.client.opensearch._types.query_dsl.MatchQuery; +import org.opensearch.client.opensearch._types.query_dsl.Query; import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.client.opensearch.core.SearchRequest; import org.opensearch.client.opensearch.core.SearchResponse; @@ -104,6 +106,20 @@ public static void main(String[] args) { entry.getValue().sterms().buckets().array().forEach(b -> LOGGER.info("{} : {}", b.key(), b.docCount())); } + // HybridSearch + Query searchQuery = Query.of( + h -> h.hybrid( + q -> q.queries( + Arrays.asList(new MatchQuery.Builder().field("text").query(FieldValue.of("Text for document 2")).build().toQuery()) + ) + ) + ); + searchRequest = new SearchRequest.Builder().query(searchQuery).build(); + searchResponse = client.search(searchRequest, IndexData.class); + for (var hit : searchResponse.hits().hits()) { + LOGGER.info("Found {} with score {}", hit.source(), hit.score()); + } + LOGGER.info("Deleting index {}", indexName); DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest.Builder().index(indexName).build(); 
client.indices().delete(deleteIndexRequest);