diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
index fa2584495..21c4e5537 100644
--- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
+++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -10,6 +10,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.function.Supplier;
 
 import org.opensearch.client.Client;
@@ -26,6 +27,7 @@
 import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
+import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.plugins.ActionPlugin;
 import org.opensearch.plugins.ExtensiblePlugin;
 import org.opensearch.plugins.IngestPlugin;
@@ -33,6 +35,7 @@
 import org.opensearch.plugins.SearchPlugin;
 import org.opensearch.repositories.RepositoriesService;
 import org.opensearch.script.ScriptService;
+import org.opensearch.search.query.QueryPhaseSearcher;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.watcher.ResourceWatcherService;
 
@@ -74,4 +77,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
         clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
         return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
     }
+
+    @Override
+    public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
+        return Optional.of(new HybridQueryPhaseSearcher());
+    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
index d3a83b6ca..e97c20dec 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
@@ -189,6 +189,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException {
             } else {
                 if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                     boost = parser.floatValue();
+                    // regular boost functionality is not supported, users should rely on score normalization methods to adjust scores
+                    if (boost != DEFAULT_BOOST) {
+                        log.error("[{}] query does not support provided value {} for [{}]", NAME, boost, BOOST_FIELD);
+                        throw new ParsingException(parser.getTokenLocation(), "[{}] query does not support [{}]", NAME, BOOST_FIELD);
+                    }
                 } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                     queryName = parser.text();
                 } else {
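A note on the `HybridQueryBuilder` change above: a non-default `boost` on the `hybrid` clause itself is now rejected at parse time, while sub-queries keep their regular boost support. A minimal JUnit-style sketch of that distinction (not part of this change; it assumes the same imports and base class as `HybridQueryBuilderTests`):

```java
public void testBoostPlacementSketch() {
    HybridQueryBuilder hybrid = new HybridQueryBuilder();
    // boost on a sub-query is still parsed and applied by the sub-query's own parser
    hybrid.add(QueryBuilders.termQuery("text", "keyword").boost(2.0f));
    hybrid.add(QueryBuilders.termQuery("text", "keyword"));
    assertNotNull(hybrid);
    // a hybrid-level "boost": 2.0 in the request body, by contrast, now fails in
    // fromXContent with ParsingException: "[hybrid] query does not support [boost]"
}
```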
diff --git a/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java
index f54ec9017..fbc820d8b 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java
@@ -6,6 +6,7 @@
 package org.opensearch.neuralsearch.search;
 
 import java.util.Arrays;
+import java.util.List;
 
 import lombok.Getter;
 import lombok.ToString;
@@ -21,23 +22,23 @@ public class CompoundTopDocs extends TopDocs {
 
     @Getter
-    private TopDocs[] compoundTopDocs;
+    private List<TopDocs> compoundTopDocs;
 
     public CompoundTopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) {
         super(totalHits, scoreDocs);
     }
 
-    public CompoundTopDocs(TotalHits totalHits, TopDocs[] docs) {
+    public CompoundTopDocs(TotalHits totalHits, List<TopDocs> docs) {
         // we pass clone of score docs from the sub-query that has most hits
         super(totalHits, cloneLargestScoreDocs(docs));
         this.compoundTopDocs = docs;
     }
 
-    private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
+    private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
         if (docs == null) {
             return null;
         }
-        ScoreDoc[] maxScoreDocs = null;
+        ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
         int maxLength = -1;
         for (TopDocs topDoc : docs) {
             if (topDoc == null || topDoc.scoreDocs == null) {
@@ -48,9 +49,6 @@ private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
                 maxScoreDocs = topDoc.scoreDocs;
             }
         }
-        if (maxScoreDocs == null) {
-            return null;
-        }
         // do deep copy
         return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
     }
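To make the `CompoundTopDocs` changes concrete: the top-level `scoreDocs` are a deep copy of the largest sub-result, and a list of all-null sub-results now produces an empty array instead of `null` (only a `null` list still yields `null`). A JUnit-style sketch in the spirit of `CompoundTopDocsTests` (not part of this change):

```java
public void testScoreDocsComeFromLargestSubResultSketch() {
    TopDocs smaller = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 0.8f) });
    TopDocs larger = new TopDocs(
        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
        new ScoreDoc[] { new ScoreDoc(1, 0.7f), new ScoreDoc(2, 0.5f) }
    );
    CompoundTopDocs compound = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), List.of(smaller, larger));
    assertEquals(2, compound.scoreDocs.length); // copied from `larger`, the sub-result with the most hits
    assertNotSame(larger.scoreDocs[0], compound.scoreDocs[0]); // deep copy, fresh ScoreDoc instances

    CompoundTopDocs allNull = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Arrays.asList(null, null));
    assertEquals(0, allNull.scoreDocs.length); // empty array now, no longer null
}
```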
diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
index d8829b19e..3fa413826 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
@@ -6,7 +6,11 @@
 package org.opensearch.neuralsearch.search;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Locale;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import lombok.Getter;
 import lombok.extern.log4j.Log4j2;
@@ -31,9 +35,7 @@ public class HybridTopScoreDocCollector implements Collector {
     private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
     private int docBase;
-    private float minCompetitiveScore;
     private final HitsThresholdChecker hitsThresholdChecker;
-    private ScoreDoc pqTop;
     private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
     private int[] totalHits;
     private final int numOfHits;
@@ -48,7 +50,6 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
     @Override
     public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
         docBase = context.docBase;
-        minCompetitiveScore = 0f;
 
         return new TopScoreDocCollector.ScorerLeafCollector() {
             HybridQueryScorer compoundQueryScorer;
@@ -56,7 +57,6 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
             @Override
             public void setScorer(Scorable scorer) throws IOException {
                 super.setScorer(scorer);
-                updateMinCompetitiveScore(scorer);
                 compoundQueryScorer = (HybridQueryScorer) scorer;
             }
@@ -93,30 +93,17 @@ public ScoreMode scoreMode() {
         return hitsThresholdChecker.scoreMode();
     }
 
-    protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
-        if (hitsThresholdChecker.isThresholdReached() && pqTop != null && pqTop.score != Float.NEGATIVE_INFINITY) { // -Infinity is the
-            // boundary score
-            // we have multiple identical doc id and collect in doc id order, we need next float
-            float localMinScore = Math.nextUp(pqTop.score);
-            if (localMinScore > minCompetitiveScore) {
-                scorer.setMinCompetitiveScore(localMinScore);
-                totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
-                minCompetitiveScore = localMinScore;
-            }
-        }
-    }
-
     /**
      * Get resulting collection of TopDocs for hybrid query after we ran search for each of its sub query
      * @return
     */
-    public TopDocs[] topDocs() {
-        TopDocs[] topDocs = new TopDocs[compoundScores.length];
-        for (int i = 0; i < compoundScores.length; i++) {
-            int qTopSize = totalHits[i];
-            TopDocs topDocsPerQuery = topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize);
-            topDocs[i] = topDocsPerQuery;
+    public List<TopDocs> topDocs() {
+        if (compoundScores == null) {
+            return new ArrayList<>();
         }
+        final List<TopDocs> topDocs = IntStream.range(0, compoundScores.length)
+            .mapToObj(i -> topDocsPerQuery(0, Math.min(totalHits[i], compoundScores[i].size()), compoundScores[i], totalHits[i]))
+            .collect(Collectors.toList());
         return topDocs;
     }
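Two effects of the collector change above are easy to miss: the min-competitive-score pruning is removed, so every sub-query keeps its complete top hits (scores across sub-queries are meant to be combined via normalization downstream, per the boost comment earlier in this change), and `topDocs()` now returns one `TopDocs` per sub-query, in sub-query order, or an empty list when nothing was collected. A small consumption sketch under those assumptions:

```java
// after searcher.search(query, hybridTopScoreDocCollector) has run
List<TopDocs> perSubQuery = hybridTopScoreDocCollector.topDocs(); // empty list if nothing was collected
for (int i = 0; i < perSubQuery.size(); i++) {
    TopDocs subQueryDocs = perSubQuery.get(i); // index i = position of the sub-query in the hybrid query
    System.out.printf("sub-query %d matched %d docs%n", i, subQueryDocs.totalHits.value);
}
```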
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
new file mode 100644
index 000000000..81b6b7ebd
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.search.query;
+
+import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+
+import lombok.extern.log4j.Log4j2;
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHitCountCollector;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
+import org.opensearch.neuralsearch.query.HybridQuery;
+import org.opensearch.neuralsearch.search.CompoundTopDocs;
+import org.opensearch.neuralsearch.search.HitsThresholdChecker;
+import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
+import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.internal.ContextIndexSearcher;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.query.QueryCollectorContext;
+import org.opensearch.search.query.QueryPhase;
+import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.TopDocsCollectorContext;
+import org.opensearch.search.rescore.RescoreContext;
+import org.opensearch.search.sort.SortAndFormats;
+
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the
+ * standard upstream searcher implementation is used.
+ */
+@Log4j2
+public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {
+
+    public boolean searchWith(
+        final SearchContext searchContext,
+        final ContextIndexSearcher searcher,
+        final Query query,
+        final LinkedList<QueryCollectorContext> collectors,
+        final boolean hasFilterCollector,
+        final boolean hasTimeout
+    ) throws IOException {
+        if (query instanceof HybridQuery) {
+            return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
+        }
+        return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
+    }
+
+    @VisibleForTesting
+    protected boolean searchWithCollector(
+        final SearchContext searchContext,
+        final ContextIndexSearcher searcher,
+        final Query query,
+        final LinkedList<QueryCollectorContext> collectors,
+        final boolean hasFilterCollector,
+        final boolean hasTimeout
+    ) throws IOException {
+        log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());
+
+        final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
+        collectors.addFirst(topDocsFactory);
+        if (searchContext.size() == 0) {
+            final TotalHitCountCollector collector = new TotalHitCountCollector();
+            searcher.search(query, collector);
+            return false;
+        }
+        final IndexReader reader = searchContext.searcher().getIndexReader();
+        int totalNumDocs = Math.max(0, reader.numDocs());
+        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
+        final boolean shouldRescore = !searchContext.rescore().isEmpty();
+        if (shouldRescore) {
+            for (RescoreContext rescoreContext : searchContext.rescore()) {
+                numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
+            }
+        }
+
+        final QuerySearchResult queryResult = searchContext.queryResult();
+
+        final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
+            numDocs,
+            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
+        );
+
+        searcher.search(query, collector);
+
+        if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
+            queryResult.terminatedEarly(false);
+        }
+
+        setTopDocsInQueryResult(queryResult, collector, searchContext);
+
+        return shouldRescore;
+    }
+
+    private void setTopDocsInQueryResult(
+        final QuerySearchResult queryResult,
+        final HybridTopScoreDocCollector collector,
+        final SearchContext searchContext
+    ) {
+        final List<TopDocs> topDocs = collector.topDocs();
+        final float maxScore = getMaxScore(topDocs);
+        final TopDocs newTopDocs = new CompoundTopDocs(getTotalHits(searchContext, topDocs), topDocs);
+        final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
+        queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
+    }
+
+    private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs) {
+        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
+        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+        if (topDocs == null || topDocs.size() == 0) {
+            return new TotalHits(0, relation);
+        }
+        long maxTotalHits = topDocs.get(0).totalHits.value;
+        for (TopDocs topDoc : topDocs) {
+            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
+        }
+        return new TotalHits(maxTotalHits, relation);
+    }
+
+    private float getMaxScore(final List<TopDocs> topDocs) {
+        if (topDocs.size() == 0) {
+            return Float.NaN;
+        } else {
+            return topDocs.stream()
+                .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
+                .map(scoreDoc -> scoreDoc.score)
+                .max(Float::compare)
+                .get();
+        }
+    }
+
+    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
+        return sortAndFormats == null ? null : sortAndFormats.formats;
+    }
+}
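A worked example of the two private helpers above, using only Lucene types (a sketch of their semantics rather than direct calls, since both are private): given sub-query results with 3, 5, and 0 hits, the compound total is the maximum across sub-queries, and the max score is the best top score, with an empty sub-result contributing the `new ScoreDoc(-1, 0.0f)` placeholder:

```java
TopDocs q1 = new TopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(7, 0.9f) });
TopDocs q2 = new TopDocs(new TotalHits(5, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 0.4f) });
TopDocs q3 = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
List<TopDocs> topDocs = List.of(q1, q2, q3);
// getTotalHits(...) -> TotalHits(5, EQUAL_TO): max of 3, 5 and 0
//   (relation becomes GREATER_THAN_OR_EQUAL_TO only when total-hit tracking is disabled)
// getMaxScore(...)  -> 0.9f: max of 0.9f, 0.4f and the 0.0f placeholder for the empty q3
// getMaxScore(List.of()) -> Float.NaN
```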
diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
index bbdfc27e4..59414b49b 100644
--- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
@@ -62,6 +62,12 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {
 
     @Before
     public void setupSettings() {
+        if (isUpdateClusterSettings()) {
+            updateClusterSettings();
+        }
+    }
+
+    protected void updateClusterSettings() {
         updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
         // default threshold for native circuit breaker is 90, it may be not enough on test runner machine
         updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
@@ -514,4 +520,8 @@ protected void deleteModel(String modelId) {
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
         );
     }
+
+    public boolean isUpdateClusterSettings() {
+        return true;
+    }
 }
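The `isUpdateClusterSettings` hook above lets an integration test suppress the automatic settings update in `setupSettings` and run it on its own schedule, which is exactly what `HybridQueryIT` does later in this change. A minimal sketch of the pattern, with a hypothetical subclass name:

```java
public class SomeNeuralIT extends BaseNeuralSearchIT { // hypothetical name, for illustration only
    @Override
    public boolean isUpdateClusterSettings() {
        return false; // skip the defaults applied by the @Before setupSettings()
    }

    @Before
    public void setUp() throws Exception {
        super.setUp();
        updateClusterSettings(); // apply the same defaults at a point the test controls
    }
}
```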
diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
index 6bfb7bf33..c2925bf43 100644
--- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
@@ -5,11 +5,19 @@
 
 package org.opensearch.neuralsearch.plugin;
 
+import static org.mockito.Mockito.mock;
+
 import java.util.List;
+import java.util.Map;
+import java.util.Optional;
 
+import org.opensearch.ingest.Processor;
+import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
+import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.plugins.SearchPlugin;
+import org.opensearch.search.query.QueryPhaseSearcher;
 import org.opensearch.test.OpenSearchTestCase;
 
 public class NeuralSearchTests extends OpenSearchTestCase {
@@ -23,4 +31,21 @@ public void testQuerySpecs() {
         assertTrue(querySpecs.stream().anyMatch(spec -> NeuralQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
         assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
     }
+
+    public void testQueryPhaseSearcher() {
+        NeuralSearch plugin = new NeuralSearch();
+        Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();
+
+        assertNotNull(queryPhaseSearcher);
+        assertFalse(queryPhaseSearcher.isEmpty());
+        assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
+    }
+
+    public void testProcessors() {
+        NeuralSearch plugin = new NeuralSearch();
+        Processor.Parameters processorParams = mock(Processor.Parameters.class);
+        Map<String, Processor.Factory> processors = plugin.getProcessors(processorParams);
+        assertNotNull(processors);
+        assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
index b63e43ca4..06b07fed5 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
@@ -11,6 +11,7 @@
 import static org.mockito.Mockito.when;
 import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
 import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
+import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST;
 import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
 import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap;
 import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD;
@@ -596,6 +597,109 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery() {
         assertEquals(termSubQuery.value(), termQueryBuilder.value());
     }
 
+    /**
+     * Tests query with boost:
+     * {
+     *     "query": {
+     *         "hybrid": {
+     *             "queries": [
+     *                 {
+     *                     "term": {
+     *                         "text": "keyword"
+     *                     }
+     *                 },
+     *                 {
+     *                     "term": {
+     *                         "text": "keyword"
+     *                     }
+     *                 }
+     *             ],
+     *             "boost" : 2.0
+     *         }
+     *     }
+     * }
+     */
+    @SneakyThrows
+    public void testBoost_whenNonDefaultBoostSet_thenFail() {
+        XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder()
+            .startObject()
+            .startArray("queries")
+            .startObject()
+            .startObject("term")
+            .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10))
+            .endObject()
+            .endObject()
+            .startObject()
+            .startObject("term")
+            .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10))
+            .endObject()
+            .endObject()
+            .endArray()
+            .field("boost", 2.0f)
+            .endObject();
+
+        NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
+            List.of(
+                new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
+                new NamedXContentRegistry.Entry(
+                    QueryBuilder.class,
+                    new ParseField(HybridQueryBuilder.NAME),
+                    HybridQueryBuilder::fromXContent
+                )
+            )
+        );
+        XContentParser contentParser = createParser(
+            namedXContentRegistry,
+            xContentBuilderWithNonDefaultBoost.contentType().xContent(),
+            BytesReference.bytes(xContentBuilderWithNonDefaultBoost)
+        );
+        contentParser.nextToken();
+
+        ParsingException exception = expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser));
+        assertThat(exception.getMessage(), containsString("query does not support [boost]"));
+    }
+
+    @SneakyThrows
+    public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() {
+        // create a query with two sub-queries and the default boost value, which must parse successfully
+        XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder()
+            .startObject()
+            .startArray("queries")
+            .startObject()
+            .startObject("term")
+            .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10))
+            .endObject()
+            .endObject()
+            .startObject()
+            .startObject("term")
+            .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10))
+            .endObject()
+            .endObject()
+            .endArray()
+            .field("boost", DEFAULT_BOOST)
+            .endObject();
+
+        NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
+            List.of(
+                new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
+                new NamedXContentRegistry.Entry(
+                    QueryBuilder.class,
+                    new ParseField(HybridQueryBuilder.NAME),
+                    HybridQueryBuilder::fromXContent
+                )
+            )
+        );
+        XContentParser contentParser = createParser(
+            namedXContentRegistry,
+            xContentBuilderWithNonDefaultBoost.contentType().xContent(),
+            BytesReference.bytes(xContentBuilderWithNonDefaultBoost)
+        );
+        contentParser.nextToken();
+
+        HybridQueryBuilder hybridQueryBuilder = HybridQueryBuilder.fromXContent(contentParser);
+        assertNotNull(hybridQueryBuilder);
+    }
+
     private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
         if (!(innerObject instanceof Map)) {
             fail("field name does not map to nested object");
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index 53a9f0dd6..830191156 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -6,15 +6,16 @@
 package org.opensearch.neuralsearch.query;
 
 import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
-import static org.opensearch.neuralsearch.TestUtils.objectToFloat;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import lombok.SneakyThrows;
@@ -24,6 +25,7 @@
 import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.knn.index.SpaceType;
+import org.opensearch.neuralsearch.TestUtils;
 import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
 
 import com.google.common.primitives.Floats;
@@ -32,8 +34,9 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
     private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index";
     private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
     private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index";
-    private static final String TEST_QUERY_TEXT = "Greetings";
-    private static final String TEST_QUERY_TEXT2 = "Salute";
+    private static final int MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX = 3;
+    private static final String TEST_QUERY_TEXT = "greetings";
+    private static final String TEST_QUERY_TEXT2 = "salute";
     private static final String TEST_QUERY_TEXT3 = "hello";
     private static final String TEST_QUERY_TEXT4 = "place";
    private static final String TEST_QUERY_TEXT5 = "welcome";
@@ -53,13 +56,26 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
     private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
+    private static final String RELATION_EQUAL_TO = "eq";
+    private static final String RELATION_GREATER_OR_EQUAL_TO = "gte";
 
     @Before
     public void setUp() throws Exception {
         super.setUp();
+        updateClusterSettings();
         modelId.compareAndSet(null, prepareModel());
     }
 
+    @Override
+    public boolean isUpdateClusterSettings() {
+        return false;
+    }
+
+    @Override
+    protected boolean preserveClusterUponCompletion() {
+        return true;
+    }
+
     /**
      * Tests basic query, example of query structure:
     * {
@@ -91,20 +107,27 @@ public void testBasicQuery_whenOneSubQuery_thenSuccessful() {
 
         Map<String, Object> searchResponseAsMap1 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilder, 10);
 
-        assertEquals(3, getHitCount(searchResponseAsMap1));
+        assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap1));
 
         List<Map<String, Object>> hits1NestedList = getNestedHits(searchResponseAsMap1);
-        List<String> ids1 = new ArrayList<>();
-        List<Double> scores1 = new ArrayList<>();
+        List<String> ids = new ArrayList<>();
+        List<Double> scores = new ArrayList<>();
         for (Map<String, Object> oneHit : hits1NestedList) {
-            ids1.add((String) oneHit.get("_id"));
-            scores1.add((Double) oneHit.get("_score"));
+            ids.add((String) oneHit.get("_id"));
+            scores.add((Double) oneHit.get("_score"));
         }
 
         // verify that scores are in desc order
-        assertTrue(IntStream.range(0, scores1.size() - 1).noneMatch(idx -> scores1.get(idx) < scores1.get(idx + 1)));
+        assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
         // verify that all ids are unique
-        assertEquals(Set.copyOf(ids1).size(), ids1.size());
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+
+        Map<String, Object> total = getTotalHits(searchResponseAsMap1);
+        assertNotNull(total.get("value"));
+        assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+        assertTrue(getMaxScore(searchResponseAsMap1).isPresent());
     }
 
@@ -125,9 +148,9 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect() {
         hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder);
         hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder);
 
-        Map<String, Object> searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 1);
+        Map<String, Object> searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3);
 
-        assertEquals(1, getHitCount(searchResponseAsMap));
+        assertEquals(2, getHitCount(searchResponseAsMap));
 
         List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
         List<Double> scores = new ArrayList<>();
@@ -135,8 +158,16 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect() {
             scores.add((Double) oneHit.get("_score"));
         }
 
-        float expectedScore = computeExpectedScore(modelId.get(), testVector1, TEST_SPACE_TYPE, TEST_QUERY_TEXT) + EXPECTED_SCORE_BM25;
-        assertEquals(expectedScore, objectToFloat(scores.get(0)), 0.001f);
+        List<Float> expectedScores = List.of(
+            computeExpectedScore(modelId.get(), testVector1, TEST_SPACE_TYPE, TEST_QUERY_TEXT),
+            computeExpectedScore(modelId.get(), testVector2, TEST_SPACE_TYPE, TEST_QUERY_TEXT),
+            computeExpectedScore(modelId.get(), testVector3, TEST_SPACE_TYPE, TEST_QUERY_TEXT)
+        );
+        List<Float> actualScores = scores.stream().map(TestUtils::objectToFloat).collect(Collectors.toList());
+        assertTrue(expectedScores.containsAll(actualScores));
+
+        float expectedMaxScore = Math.max(expectedScores.stream().max(Float::compareTo).get(), EXPECTED_SCORE_BM25);
+        assertEquals(expectedMaxScore, getMaxScore(searchResponseAsMap).get(), 0.001f);
     }
 
@@ -176,8 +207,8 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() {
         initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
 
         TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
-        TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_2, TEST_QUERY_TEXT4);
-        TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_3, TEST_QUERY_TEXT5);
+        TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
+        TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
 
         BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
         boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3);
@@ -190,17 +221,23 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() {
         assertEquals(3, getHitCount(searchResponseAsMap1));
 
         List<Map<String, Object>> hits1NestedList = getNestedHits(searchResponseAsMap1);
-        List<String> ids1 = new ArrayList<>();
-        List<Double> scores1 = new ArrayList<>();
+        List<String> ids = new ArrayList<>();
+        List<Double> scores = new ArrayList<>();
         for (Map<String, Object> oneHit : hits1NestedList) {
-            ids1.add((String) oneHit.get("_id"));
-            scores1.add((Double) oneHit.get("_score"));
+            ids.add((String) oneHit.get("_id"));
+            scores.add((Double) oneHit.get("_score"));
         }
 
         // verify that scores are in desc order
-        assertTrue(IntStream.range(0, scores1.size() - 1).noneMatch(idx -> scores1.get(idx) < scores1.get(idx + 1)));
+        assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
         // verify that all ids are unique
-        assertEquals(Set.copyOf(ids1).size(), ids1.size());
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+
+        Map<String, Object> total = getTotalHits(searchResponseAsMap1);
+        assertNotNull(total.get("value"));
+        assertEquals(3, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
     }
@@ -241,7 +278,7 @@ public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() {
 
         Map<String, Object> searchResponseAsMap1 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10);
 
-        assertEquals(3, getHitCount(searchResponseAsMap1));
+        assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap1));
 
         List<Map<String, Object>> hits1NestedList = getNestedHits(searchResponseAsMap1);
         List<String> ids1 = new ArrayList<>();
@@ -251,6 +288,12 @@ public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() {
             scores1.add((Double) oneHit.get("_score"));
         }
 
+        Map<String, Object> total = getTotalHits(searchResponseAsMap1);
+        assertNotNull(total.get("value"));
+        assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+
         // verify that scores are in desc order
         assertTrue(IntStream.range(0, scores1.size() - 1).noneMatch(idx -> scores1.get(idx) < scores1.get(idx + 1)));
         // verify that all ids are unique
@@ -263,7 +306,7 @@ public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() {
 
         Map<String, Object> searchResponseAsMap2 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10);
 
-        assertEquals(3, getHitCount(searchResponseAsMap2));
+        assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap2));
 
         List<String> ids2 = new ArrayList<>();
         List<Double> scores2 = new ArrayList<>();
@@ -271,6 +314,12 @@ public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() {
             ids2.add((String) oneHit.get("_id"));
             scores2.add((Double) oneHit.get("_score"));
         }
+
+        Map<String, Object> total2 = getTotalHits(searchResponseAsMap2);
+        assertNotNull(total2.get("value"));
+        assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total2.get("value"));
+        assertNotNull(total2.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total2.get("relation"));
 
         // doc ids must match same from the previous query, order of sub-queries doesn't change the result
         assertEquals(ids1, ids2);
         assertEquals(scores1, scores2);
@@ -289,44 +338,13 @@ public void test_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() {
         Map<String, Object> searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderOnlyTerm, 10);
 
         assertEquals(0, getHitCount(searchResponseAsMap));
-    }
+        assertTrue(getMaxScore(searchResponseAsMap).isEmpty());
 
-    @SneakyThrows
-    public void testBoostQuery_whenHybridWithBoost_thenScoreMultipliedCorrectly() {
-        initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
-
-        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(
-            TEST_KNN_VECTOR_FIELD_NAME_1,
-            TEST_QUERY_TEXT,
-            modelId.get(),
-            1,
-            null,
-            null
-        );
-        TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
-
-        final float boost = 2.0f;
-
-        HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
-        hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder);
-        hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder);
-        hybridQueryBuilderNeuralThenTerm.boost(boost);
-
-        Map<String, Object> searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 1);
-
-        assertEquals(1, getHitCount(searchResponseAsMap));
-
-        List<Map<String, Object>> hits1NestedList = getNestedHits(searchResponseAsMap);
-        List<String> ids1 = new ArrayList<>();
-        List<Double> scores1 = new ArrayList<>();
-        for (Map<String, Object> oneHit : hits1NestedList) {
-            ids1.add((String) oneHit.get("_id"));
-            scores1.add((Double) oneHit.get("_score"));
-        }
-
-        float expectedScore = boost * (computeExpectedScore(modelId.get(), testVector1, TEST_SPACE_TYPE, TEST_QUERY_TEXT)
-            + EXPECTED_SCORE_BM25);
-        assertEquals(expectedScore, objectToFloat(scores1.get(0)), 0.001f);
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertEquals(0, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
     }
 
@@ -364,7 +382,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
             "2",
             List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2),
             List.of(Floats.asList(testVector2).toArray(), Floats.asList(testVector2).toArray()),
-            Collections.singletonList(TEST_TEXT_FIELD_NAME_2),
+            Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
             Collections.singletonList(TEST_DOC_TEXT2)
         );
         addKnnDoc(
@@ -372,7 +390,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
             "3",
             List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2),
             List.of(Floats.asList(testVector3).toArray(), Floats.asList(testVector3).toArray()),
-            Collections.singletonList(TEST_TEXT_FIELD_NAME_3),
+            Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
             Collections.singletonList(TEST_DOC_TEXT3)
         );
         assertEquals(3, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME));
@@ -409,8 +427,18 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
         }
     }
 
-    private static List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
+    private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
         Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
         return (List<Map<String, Object>>) hitsMap.get("hits");
     }
+
+    private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return (Map<String, Object>) hitsMap.get("total");
+    }
+
+    private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java
index 4a7a0a875..e954004c7 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java
@@ -11,13 +11,9 @@
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Set;
 import java.util.stream.Stream;
 
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.commons.lang3.tuple.Pair;
 import org.apache.lucene.analysis.standard.StandardAnalyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
@@ -51,8 +47,6 @@
 import org.opensearch.script.ScriptService;
 import org.opensearch.test.OpenSearchTestCase;
 
-import com.carrotsearch.randomizedtesting.RandomizedTest;
-
 public abstract class OpenSearchQueryTestCase extends OpenSearchTestCase {
 
     protected final MapperService createMapperService(Version version, XContentBuilder mapping) throws IOException {
@@ -139,25 +133,6 @@ protected static Document getDocument(String fieldName, int docId, String fieldValue, FieldType ft) {
         return doc;
     }
 
-    private Pair<int[], float[]> generateDocuments(int maxDocId) {
-        final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
-        final int[] docs = new int[numDocs];
-        final Set<Integer> uniqueDocs = new HashSet<>();
-        while (uniqueDocs.size() < numDocs) {
-            uniqueDocs.add(random().nextInt(maxDocId));
-        }
-        int i = 0;
-        for (int doc : uniqueDocs) {
-            docs[i++] = doc;
-        }
-        Arrays.sort(docs);
-        final float[] scores = new float[numDocs];
-        for (int j = 0; j < numDocs; ++j) {
-            scores[j] = random().nextFloat();
-        }
-        return new ImmutablePair<>(docs, scores);
-    }
-
     protected static Weight fakeWeight(Query query) {
 
         return new Weight(query) {
diff --git a/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java
index 663e023b6..0c79d7f73 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java
@@ -5,6 +5,9 @@
 
 package org.opensearch.neuralsearch.search;
 
+import java.util.Arrays;
+import java.util.List;
+
 import org.apache.commons.lang3.RandomUtils;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
@@ -25,7 +28,7 @@ public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() {
             new ScoreDoc[] { new ScoreDoc(4, RandomUtils.nextFloat()), new ScoreDoc(5, RandomUtils.nextFloat()) }
         );
-        TopDocs[] topDocs = new TopDocs[] { topDocs1, topDocs2 };
+        List<TopDocs> topDocs = List.of(topDocs1, topDocs2);
         CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs);
         assertNotNull(compoundTopDocs);
         assertEquals(topDocs, compoundTopDocs.getCompoundTopDocs());
@@ -49,7 +52,7 @@ public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWithMostHits() {
             new TotalHits(2, TotalHits.Relation.EQUAL_TO),
             new ScoreDoc[] { new ScoreDoc(2, RandomUtils.nextFloat()), new ScoreDoc(4, RandomUtils.nextFloat()) }
         );
-        TopDocs[] topDocs = new TopDocs[] { topDocs1, topDocs2 };
+        List<TopDocs> topDocs = List.of(topDocs1, topDocs2);
         CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs);
         assertNotNull(compoundTopDocs);
         assertNotNull(compoundTopDocs.scoreDocs);
@@ -57,15 +60,16 @@ public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWithMostHits() {
     }
 
     public void testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() {
-        CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (TopDocs[]) null);
+        CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (List<TopDocs>) null);
         assertNotNull(compoundTopDocs);
         assertNull(compoundTopDocs.scoreDocs);
 
         CompoundTopDocs compoundTopDocsWithNullArray = new CompoundTopDocs(
             new TotalHits(0, TotalHits.Relation.EQUAL_TO),
-            new TopDocs[] { null, null }
+            Arrays.asList(null, null)
         );
         assertNotNull(compoundTopDocsWithNullArray);
-        assertNull(compoundTopDocsWithNullArray.scoreDocs);
+        assertNotNull(compoundTopDocsWithNullArray.scoreDocs);
+        assertEquals(0, compoundTopDocsWithNullArray.scoreDocs.length);
     }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
index 74171cef6..747b05992 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
@@ -204,10 +204,10 @@ public void testTopDocs_whenCreateNewAndGetTopDocs_thenSuccessful() {
             doc = iterator.nextDoc();
         }
 
-        TopDocs[] topDocs = hybridTopScoreDocCollector.topDocs();
+        List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
 
         assertNotNull(topDocs);
-        assertEquals(2, topDocs.length);
+        assertEquals(2, topDocs.size());
 
         for (TopDocs topDoc : topDocs) {
             // assert results for each sub-query, there must be correct number of matches, all doc id are correct and scores must be desc
@@ -296,14 +296,14 @@ public void testTopDocs_whenMatchedDocsDifferentForEachSubQuery_thenSuccessful() {
             doc = iterator.nextDoc();
         }
 
-        TopDocs[] topDocs = hybridTopScoreDocCollector.topDocs();
+        List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
 
         assertNotNull(topDocs);
-        assertEquals(4, topDocs.length);
+        assertEquals(4, topDocs.size());
 
         // assert result for each sub query
         // term query 1
-        TopDocs topDocQuery1 = topDocs[0];
+        TopDocs topDocQuery1 = topDocs.get(0);
         assertEquals(docIdsQuery1.length, topDocQuery1.totalHits.value);
         ScoreDoc[] scoreDocsQuery1 = topDocQuery1.scoreDocs;
         assertNotNull(scoreDocsQuery1);
@@ -314,7 +314,7 @@ public void testTopDocs_whenMatchedDocsDifferentForEachSubQuery_thenSuccessful() {
         List<Integer> resultDocIdsQuery1 = Arrays.stream(scoreDocsQuery1).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList());
         assertTrue(Arrays.stream(docIdsQuery1).allMatch(resultDocIdsQuery1::contains));
         // term query 2
-        TopDocs topDocQuery2 = topDocs[1];
+        TopDocs topDocQuery2 = topDocs.get(1);
         assertEquals(docIdsQuery2.length, topDocQuery2.totalHits.value);
         ScoreDoc[] scoreDocsQuery2 = topDocQuery2.scoreDocs;
         assertNotNull(scoreDocsQuery2);
@@ -325,13 +325,13 @@ public void testTopDocs_whenMatchedDocsDifferentForEachSubQuery_thenSuccessful() {
         List<Integer> resultDocIdsQuery2 = Arrays.stream(scoreDocsQuery2).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList());
         assertTrue(Arrays.stream(docIdsQuery2).allMatch(resultDocIdsQuery2::contains));
         // no match query
-        TopDocs topDocQuery3 = topDocs[2];
+        TopDocs topDocQuery3 = topDocs.get(2);
         assertEquals(0, topDocQuery3.totalHits.value);
         ScoreDoc[] scoreDocsQuery3 = topDocQuery3.scoreDocs;
         assertNotNull(scoreDocsQuery3);
         assertEquals(0, scoreDocsQuery3.length);
         // all match query
-        TopDocs topDocQueryMatchAll = topDocs[3];
+        TopDocs topDocQueryMatchAll = topDocs.get(3);
         assertEquals(docIdsQueryMatchAll.length, topDocQueryMatchAll.totalHits.value);
         ScoreDoc[] scoreDocsQueryMatchAll = topDocQueryMatchAll.scoreDocs;
         assertNotNull(scoreDocsQueryMatchAll);
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
new file mode 100644
index 000000000..f08dc2332
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
@@ -0,0 +1,387 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.search.query;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+
+import lombok.SneakyThrows;
+
+import org.apache.lucene.document.FieldType;
+import org.apache.lucene.document.TextField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexOptions;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.analysis.MockAnalyzer;
+import org.opensearch.action.OriginalIndices;
+import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
+import org.opensearch.index.Index;
+import org.opensearch.index.mapper.TextFieldMapper;
+import org.opensearch.index.query.MatchAllQueryBuilder;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.index.query.QueryShardContext;
+import org.opensearch.index.query.TermQueryBuilder;
+import org.opensearch.index.shard.ShardId;
+import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
+import org.opensearch.neuralsearch.query.HybridQueryBuilder;
+import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
+import org.opensearch.neuralsearch.search.CompoundTopDocs;
+import org.opensearch.search.SearchShardTarget;
+import org.opensearch.search.internal.ContextIndexSearcher;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.query.QueryCollectorContext;
+import org.opensearch.search.query.QuerySearchResult;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+
+public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
+    private static final String VECTOR_FIELD_NAME = "vectorField";
+    private static final String TEXT_FIELD_NAME = "field";
+    private static final String TEST_DOC_TEXT1 = "Hello world";
+    private static final String TEST_DOC_TEXT2 = "Hi to this place";
+    private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone";
+    private static final String TEST_DOC_TEXT4 = "This is really nice place to be";
+    private static final String QUERY_TEXT1 = "hello";
+    private static final String QUERY_TEXT2 = "randomkeyword";
+    private static final String QUERY_TEXT3 = "place";
+    private static final Index dummyIndex = new Index("dummy", "dummy");
+    private static final String MODEL_ID = "mfgfgdsfgfdgsde";
+    private static final int K = 10;
+    private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder();
+
+    @SneakyThrows
+    public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
+        HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
+        when(mockQueryShardContext.index()).thenReturn(dummyIndex);
+        when(mockKNNVectorField.getDimension()).thenReturn(4);
+        when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        Directory directory = newDirectory();
+        IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft));
+        w.commit();
+
+        IndexReader reader = DirectoryReader.open(w);
+
+        ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
+            reader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            true,
+            null
+        );
+
+        SearchContext searchContext = mock(SearchContext.class);
+        ShardId shardId = new ShardId(dummyIndex, 1);
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            randomAlphaOfLength(10),
+            shardId,
+            randomAlphaOfLength(10),
+            OriginalIndices.NONE
+        );
+        when(searchContext.shardTarget()).thenReturn(shardTarget);
+        when(searchContext.searcher()).thenReturn(contextIndexSearcher);
+
+        LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
+        boolean hasFilterCollector = randomBoolean();
+        boolean hasTimeout = randomBoolean();
+
+        HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
+
+        TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
+        queryBuilder.add(termSubQuery);
+
+        Query query = queryBuilder.toQuery(mockQueryShardContext);
+        when(searchContext.query()).thenReturn(query);
+
+        hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
+
+        releaseResources(directory, w, reader);
+
+        verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean());
+    }
+
+    @SneakyThrows
+    public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() {
+        HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        Directory directory = newDirectory();
+        IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft));
+        w.commit();
+
+        IndexReader reader = DirectoryReader.open(w);
+
+        ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
+            reader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            true,
+            null
+        );
+
+        SearchContext searchContext = mock(SearchContext.class);
+        ShardId shardId = new ShardId(dummyIndex, 1);
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            randomAlphaOfLength(10),
+            shardId,
+            randomAlphaOfLength(10),
+            OriginalIndices.NONE
+        );
+        when(searchContext.shardTarget()).thenReturn(shardTarget);
+        when(searchContext.searcher()).thenReturn(contextIndexSearcher);
+        when(searchContext.queryResult()).thenReturn(new QuerySearchResult());
+
+        LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
+        boolean hasFilterCollector = randomBoolean();
+        boolean hasTimeout = randomBoolean();
+
+        TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
+
+        Query query = termSubQuery.toQuery(mockQueryShardContext);
+        when(searchContext.query()).thenReturn(query);
+
+        hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
+
+        releaseResources(directory, w, reader);
+
+        verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean());
+    }
+
+    @SneakyThrows
+    public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
+        HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        when(mockQueryShardContext.index()).thenReturn(dummyIndex);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        Directory directory = newDirectory();
+        IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+        int docId1 = RandomizedTest.randomInt();
+        int docId2 = RandomizedTest.randomInt();
+        int docId3 = RandomizedTest.randomInt();
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft));
+        w.commit();
+
+        IndexReader reader = DirectoryReader.open(w);
+
+        ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
+            reader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            true,
+            null
+        );
+
+        SearchContext searchContext = mock(SearchContext.class);
+        ShardId shardId = new ShardId(dummyIndex, 1);
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            randomAlphaOfLength(10),
+            shardId,
+            randomAlphaOfLength(10),
+            OriginalIndices.NONE
+        );
+        when(searchContext.shardTarget()).thenReturn(shardTarget);
+        when(searchContext.searcher()).thenReturn(contextIndexSearcher);
+        when(searchContext.size()).thenReturn(3);
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        when(searchContext.queryResult()).thenReturn(querySearchResult);
+
+        LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
+        boolean hasFilterCollector = randomBoolean();
+        boolean hasTimeout = randomBoolean();
+
+        HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
+
+        TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
+        queryBuilder.add(termSubQuery);
+
+        Query query = queryBuilder.toQuery(mockQueryShardContext);
+        when(searchContext.query()).thenReturn(query);
+
+        hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
+
+        assertNotNull(querySearchResult.topDocs());
+        TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
+        TopDocs topDocs = topDocsAndMaxScore.topDocs;
+        assertEquals(1, topDocs.totalHits.value);
+        assertTrue(topDocs instanceof CompoundTopDocs);
+        List<TopDocs> compoundTopDocs = ((CompoundTopDocs) topDocs).getCompoundTopDocs();
+        assertNotNull(compoundTopDocs);
+        assertEquals(1, compoundTopDocs.size());
+        TopDocs subQueryTopDocs = compoundTopDocs.get(0);
+        assertEquals(1, subQueryTopDocs.totalHits.value);
+        assertNotNull(subQueryTopDocs.scoreDocs);
+        assertEquals(1, subQueryTopDocs.scoreDocs.length);
+        ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[0];
+        assertNotNull(scoreDoc);
+        int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue());
+        assertEquals(docId1, actualDocId);
+        assertTrue(scoreDoc.score > 0.0f);
+
+        releaseResources(directory, w, reader);
+    }
+
+    @SneakyThrows
+    public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridResultsAreSet() {
+        HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        when(mockQueryShardContext.index()).thenReturn(dummyIndex);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        Directory directory = newDirectory();
+        IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+        int docId1 = RandomizedTest.randomInt();
+        int docId2 = RandomizedTest.randomInt();
+        int docId3 = RandomizedTest.randomInt();
+        int docId4 = RandomizedTest.randomInt();
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft));
+        w.commit();
+
+        IndexReader reader = DirectoryReader.open(w);
+
+        ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
+            reader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            true,
+            null
+        );
+
+        SearchContext searchContext = mock(SearchContext.class);
+        ShardId shardId = new ShardId(dummyIndex, 1);
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            randomAlphaOfLength(10),
+            shardId,
+            randomAlphaOfLength(10),
+            OriginalIndices.NONE
+        );
+        when(searchContext.shardTarget()).thenReturn(shardTarget);
+        when(searchContext.searcher()).thenReturn(contextIndexSearcher);
+        when(searchContext.size()).thenReturn(4);
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        when(searchContext.queryResult()).thenReturn(querySearchResult);
+
+        LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
+        boolean hasFilterCollector = randomBoolean();
+        boolean hasTimeout = randomBoolean();
+
+        HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
+
+        queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
+        queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
+        queryBuilder.add(QueryBuilders.matchAllQuery());
+
+        Query query = queryBuilder.toQuery(mockQueryShardContext);
+        when(searchContext.query()).thenReturn(query);
+
+        hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
+
+        assertNotNull(querySearchResult.topDocs());
+        TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
+        TopDocs topDocs = topDocsAndMaxScore.topDocs;
+        assertEquals(4, topDocs.totalHits.value);
+        assertTrue(topDocs instanceof CompoundTopDocs);
+        List<TopDocs> compoundTopDocs = ((CompoundTopDocs) topDocs).getCompoundTopDocs();
+        assertNotNull(compoundTopDocs);
+        assertEquals(3, compoundTopDocs.size());
+
+        TopDocs subQueryTopDocs1 = compoundTopDocs.get(0);
+        List<Integer> expectedIds1 = List.of(docId1);
+        assertQueryResults(subQueryTopDocs1, expectedIds1, reader);
+
+        TopDocs subQueryTopDocs2 = compoundTopDocs.get(1);
+        List<Integer> expectedIds2 = List.of();
+        assertQueryResults(subQueryTopDocs2, expectedIds2, reader);
+
+        TopDocs subQueryTopDocs3 = compoundTopDocs.get(2);
+        List<Integer> expectedIds3 = List.of(docId1, docId2, docId3, docId4);
+        assertQueryResults(subQueryTopDocs3, expectedIds3, reader);
+
+        releaseResources(directory, w, reader);
+    }
+
+    @SneakyThrows
+    private void assertQueryResults(TopDocs subQueryTopDocs, List<Integer> expectedDocIds, IndexReader reader) {
+        assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value);
+        assertNotNull(subQueryTopDocs.scoreDocs);
+        assertEquals(expectedDocIds.size(), subQueryTopDocs.scoreDocs.length);
+        assertEquals(TotalHits.Relation.EQUAL_TO, subQueryTopDocs.totalHits.relation);
+        for (int i = 0; i < expectedDocIds.size(); i++) {
+            int expectedDocId = expectedDocIds.get(i);
+            ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[i];
+            assertNotNull(scoreDoc);
+            int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue());
+            assertEquals(expectedDocId, actualDocId);
+            assertTrue(scoreDoc.score > 0.0f);
+        }
+    }
+
+    private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException {
+        w.close();
+        reader.close();
+        directory.close();
+    }
+}