From f5a47a6cc829aed86f7db7296d7365675b31f04d Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 22 Apr 2024 17:14:54 +0800 Subject: [PATCH] enhancements: support neural_sparse query by tokens (#693) * enhancements: support neural sparse query by tokens Signed-off-by: zhichao-aws * add empty string logics for streams Signed-off-by: zhichao-aws * fix fromXContent error case logic Signed-off-by: zhichao-aws * add ut Signed-off-by: zhichao-aws * ut it Signed-off-by: zhichao-aws * updates for comments Signed-off-by: zhichao-aws * nit Signed-off-by: zhichao-aws * fix comments Signed-off-by: zhichao-aws --------- Signed-off-by: zhichao-aws --- CHANGELOG.md | 1 + .../query/NeuralSparseQueryBuilder.java | 101 ++++++++++---- .../query/NeuralSparseQueryBuilderTests.java | 125 +++++++++++++++++- .../query/NeuralSparseQueryIT.java | 47 ++++++- 4 files changed, 243 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eae99470c..a40e3edb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - BWC tests for text chunking processor ([#661](https://github.com/opensearch-project/neural-search/pull/661)) - Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670)) +- Allowing query by raw tokens in neural_sparse query ([#693](https://github.com/opensearch-project/neural-search/pull/693)) ### Bug Fixes - Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 319f4b356..6c3b06967 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ 
b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -62,19 +63,15 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder> queryTokensSupplier; private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0; + public static void initialize(MLCommonsClientAccessor mlClient) { + NeuralSparseQueryBuilder.ML_CLIENT = mlClient; + } + /** * Constructor from stream input * @@ -102,21 +103,31 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); this.queryTokensSupplier = () -> queryTokens; } + // to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API + // after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead + if (StringUtils.EMPTY.equals(this.queryText)) { + this.queryText = null; + } + if (StringUtils.EMPTY.equals(this.modelId)) { + this.modelId = null; + } } @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(fieldName); - out.writeString(queryText); + out.writeString(this.fieldName); + // to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API + // after supporting query by tokens, queryText and modelId can be null. 
here we write an empty String instead + out.writeString(StringUtils.defaultString(this.queryText, StringUtils.EMPTY)); if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { out.writeOptionalString(this.modelId); } else { - out.writeString(this.modelId); + out.writeString(StringUtils.defaultString(this.modelId, StringUtils.EMPTY)); } out.writeOptionalFloat(maxTokenScore); - if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) { + if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) { out.writeBoolean(true); - out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); + out.writeMap(this.queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); } else { out.writeBoolean(false); } @@ -126,11 +137,16 @@ protected void doWriteTo(StreamOutput out) throws IOException { protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); - xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); + if (Objects.nonNull(queryText)) { + xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); + } if (Objects.nonNull(modelId)) { xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); } - if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore); + if (Objects.nonNull(maxTokenScore)) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore); + if (Objects.nonNull(queryTokensSupplier) && Objects.nonNull(queryTokensSupplier.get())) { + xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokensSupplier.get()); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -144,6 +160,16 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * 
"max_token_score": float (optional) * } * + * or + * "SAMPLE_FIELD": { + * "query_tokens": { + * "token_a": float, + * "token_b": float, + * ... + * } + * } + * + * * @param parser XContentParser * @return NeuralQueryBuilder * @throws IOException can be thrown by parser @@ -171,16 +197,40 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw } requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query"); - requireValue( - sparseEncodingQueryBuilder.queryText(), - String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) - ); - if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + if (Objects.isNull(sparseEncodingQueryBuilder.queryTokensSupplier())) { requireValue( - sparseEncodingQueryBuilder.modelId(), - String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) + sparseEncodingQueryBuilder.queryText(), + String.format( + Locale.ROOT, + "either %s field or %s field must be provided for [%s] query", + QUERY_TEXT_FIELD.getPreferredName(), + QUERY_TOKENS_FIELD.getPreferredName(), + NAME + ) + ); + if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + requireValue( + sparseEncodingQueryBuilder.modelId(), + String.format( + Locale.ROOT, + "using %s, %s field must be provided for [%s] query", + QUERY_TEXT_FIELD.getPreferredName(), + MODEL_ID_FIELD.getPreferredName(), + NAME + ) + ); + } + } + + if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.queryText())) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s field can not be empty", QUERY_TEXT_FIELD.getPreferredName()) ); } + if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.modelId())) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName())); + } + return sparseEncodingQueryBuilder; } @@ -207,6 +257,9 @@ 
private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName) ); } + } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + Map queryTokens = parser.map(HashMap::new, XContentParser::floatValue); + sparseEncodingQueryBuilder.queryTokensSupplier(() -> queryTokens); } else { throw new ParsingException( parser.getTokenLocation(), @@ -293,14 +346,14 @@ private static void validateQueryTokens(Map queryTokens) { @Override protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; - if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false; - if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false; + if (Objects.isNull(obj) || getClass() != obj.getClass()) return false; + if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) return false; + if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) .append(modelId, obj.modelId) .append(maxTokenScore, obj.maxTokenScore); - if (queryTokensSupplier != null) { + if (Objects.nonNull(queryTokensSupplier)) { equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get()); } return equalsBuilder.isEquals(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 90b635e7c..1b7b85606 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -14,6 +14,7 @@ import static 
org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD; +import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TOKENS_FIELD; import java.io.IOException; import java.util.List; @@ -23,6 +24,7 @@ import java.util.function.BiConsumer; import java.util.function.Supplier; +import org.apache.commons.lang.StringUtils; import org.apache.lucene.document.FeatureField; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -95,6 +97,32 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); } + @SneakyThrows + public void testFromXContent_whenBuiltWithQueryTokens_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_tokens": { + "token_a": float_score_a, + "token_b": float_score_b + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS_SUPPLIER.get()) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); + assertEquals(QUERY_TOKENS_SUPPLIER.get(), sparseEncodingQueryBuilder.queryTokensSupplier().get()); + } + @SneakyThrows public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { /* @@ -276,13 +304,56 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); } + @SneakyThrows + public void 
testFromXContent_whenBuildWithEmptyQuery_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), StringUtils.EMPTY) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithEmptyModelId_thenFail() { + /* + { + "VECTOR_FIELD": { + "model_id": "" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), StringUtils.EMPTY) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); + } + @SuppressWarnings("unchecked") @SneakyThrows public void testToXContent() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) .modelId(MODEL_ID) .queryText(QUERY_TEXT) - .maxTokenScore(MAX_TOKEN_SCORE); + .maxTokenScore(MAX_TOKEN_SCORE) + .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -308,15 +379,27 @@ public void testToXContent() { assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0); + Map parsedQueryTokens = (Map) secondInnerMap.get(QUERY_TOKENS_FIELD.getPreferredName()); + 
assertEquals(QUERY_TOKENS_SUPPLIER.get().keySet(), parsedQueryTokens.keySet()); + for (Map.Entry entry : QUERY_TOKENS_SUPPLIER.get().entrySet()) { + assertEquals(entry.getValue(), parsedQueryTokens.get(entry.getKey()).floatValue(), 0); + } + } + + public void testStreams_whenCurrentVersion_thenSuccess() { + setUpClusterService(Version.CURRENT); + testStreams(); + testStreamsWithQueryTokensOnly(); } public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() { setUpClusterService(Version.V_2_12_0); testStreams(); + testStreamsWithQueryTokensOnly(); } @SneakyThrows - public void testStreams() { + private void testStreams() { NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); @@ -356,6 +439,26 @@ public void testStreams() { assertEquals(original, copy); } + @SneakyThrows + private void testStreamsWithQueryTokensOnly() { + NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder(); + original.fieldName(FIELD_NAME); + original.queryTokensSupplier(QUERY_TOKENS_SUPPLIER); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + + FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) + ); + + NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput); + assertEquals(original, copy); + } + public void testHashAndEquals() { String fieldName1 = "field 1"; String fieldName2 = "field 2"; @@ -459,6 +562,18 @@ public void testHashAndEquals() { .queryName(queryName1) .queryTokensSupplier(() -> queryTokens2); + // Identical to sparseEncodingQueryBuilder_baseline except null query text + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1) + 
.modelId(modelId1) + .boost(boost1) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline except null model id + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .boost(boost1) + .queryName(queryName1); + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); @@ -491,6 +606,12 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens); assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullQueryText); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullQueryText.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullModelId); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullModelId.hashCode()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java index d43d252b9..f4b1a29c8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.query; import org.opensearch.neuralsearch.BaseNeuralSearchIT; + +import static org.opensearch.neuralsearch.TestUtils.createRandomTokenWeightMap; import static org.opensearch.neuralsearch.TestUtils.objectToFloat; import java.util.List; @@ -76,6 +78,44 @@ public void testBasicQueryUsingQueryText() { } } + /** + * Tests 
basic query with boost: + * { + * "query": { + * "neural_sparse": { + * "text_sparse": { + * "query_tokens": { + * "hello": float, + * "world": float, + * "a": float, + * "b": float, + * "c": float + * }, + * "boost": 2 + * } + * } + * } + * } + */ + @SneakyThrows + public void testBasicQueryUsingQueryTokens() { + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + Map queryTokens = createRandomTokenWeightMap(TEST_TOKENS); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryTokensSupplier(() -> queryTokens) + .boost(2.0f); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(testRankFeaturesDoc, sparseEncodingQueryBuilder.queryTokensSupplier().get()); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, null, null); + } + } + /** * Tests rescore query: * { @@ -180,11 +220,8 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { * "model_id": "dcsdcasd" * } * }, - * "neural_sparse": { - * "field2": { - * "query_text": "Hello world a b", - * "model_id": "dcsdcasd" - * } + * "match": { + * "field2": "Hello world a b" * } * ] * }