From cc4afbf298022cb6f1201cf193ed9e845cc16637 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 23 Sep 2024 17:32:28 -0700 Subject: [PATCH] Initial version for rescorer Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/bwc/HybridSearchIT.java | 34 ++- .../search/query/HybridCollectorManager.java | 116 ++++++-- .../query/HybridQueryPhaseSearcher.java | 4 +- .../neuralsearch/query/HybridQuerySortIT.java | 64 +++++ .../query/HybridCollectorManagerTests.java | 270 ++++++++++++++++++ 6 files changed, 448 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf8d14139..675ea5983 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements - Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907)) +- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index eeae7f7dd..35ee47b0d 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -9,6 +9,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; + import org.opensearch.index.query.MatchQueryBuilder; import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; @@ -17,6 +19,8 @@ import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; import static org.opensearch.neuralsearch.util.TestUtils.getModelId; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -31,6 +35,8 @@ public class HybridSearchIT extends AbstractRollingUpgradeTestCase { private static final String TEXT_UPGRADED = "Hi earth"; private static final String QUERY = "Hi world"; private static final int NUM_DOCS_PER_ROUND = 1; + private static final String VECTOR_EMBEDDING_FIELD = "passage_embedding"; + protected static final String RESCORE_QUERY = "hi"; private static String modelId = ""; // Test rolling-upgrade normalization processor when index with multiple shards @@ -62,12 +68,13 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr if (isFirstMixedRound()) { totalDocsCountMixed = NUM_DOCS_PER_ROUND; HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null); - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder); + QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f); + validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer); addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null); } else { totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null); - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder); + validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null); } break; case UPGRADED: @@ -77,9 +84,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr loadModel(modelId); addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null); HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder); + QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer); hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault()); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer); } finally { wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); } @@ -89,15 +97,19 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr } } - private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, HybridQueryBuilder hybridQueryBuilder) - throws Exception { + private void validateTestIndexOnUpgrade( + final int numberOfDocs, + final String modelId, + HybridQueryBuilder hybridQueryBuilder, + QueryBuilder rescorer + ) throws Exception { int docCount = getDocCount(getIndexNameForTest()); assertEquals(numberOfDocs, docCount); loadModel(modelId); Map searchResponseAsMap = search( getIndexNameForTest(), hybridQueryBuilder, - null, + rescorer, 1, Map.of("search_pipeline", SEARCH_PIPELINE_NAME) ); @@ -113,18 +125,18 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod private HybridQueryBuilder getQueryBuilder( final String modelId, final Map methodParameters, - final RescoreContext rescoreContext + final RescoreContext rescoreContextForNeuralQuery ) { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); - neuralQueryBuilder.fieldName("passage_embedding"); + neuralQueryBuilder.fieldName(VECTOR_EMBEDDING_FIELD); neuralQueryBuilder.modelId(modelId); neuralQueryBuilder.queryText(QUERY); neuralQueryBuilder.k(5); if (methodParameters != null) { neuralQueryBuilder.methodParameters(methodParameters); } - if (rescoreContext != null) { - neuralQueryBuilder.rescoreContext(rescoreContext); + if (Objects.nonNull(rescoreContextForNeuralQuery)) { + neuralQueryBuilder.rescoreContext(rescoreContextForNeuralQuery); } MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY); diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 4eb49e845..a0d444fe6 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -6,6 +6,7 @@ import java.util.Locale; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; @@ -18,6 +19,7 @@ import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.FieldDoc; +import org.opensearch.OpenSearchException; import org.opensearch.common.Nullable; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; @@ -33,6 +35,7 @@ import org.opensearch.search.query.MultiCollectorWrapper; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.sort.SortAndFormats; import java.io.IOException; @@ -55,6 +58,7 @@ * In most cases it will be wrapped in MultiCollectorManager. */ @RequiredArgsConstructor +@Log4j2 public abstract class HybridCollectorManager implements CollectorManager { private final int numHits; @@ -67,6 +71,7 @@ public abstract class HybridCollectorManager implements CollectorManager getSearchResults(final List results = new ArrayList<>(); DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats); for (HybridSearchCollector collector : hybridSearchCollectors) { - TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, docValueFormats); + boolean isSortEnabled = docValueFormats != null; + TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, isSortEnabled); results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats)); } return results; } - private TopDocsAndMaxScore getTopDocsAndAndMaxScore( - final HybridSearchCollector hybridSearchCollector, - final DocValueFormat[] docValueFormats - ) { - TopDocs newTopDocs; + private TopDocsAndMaxScore getTopDocsAndAndMaxScore(final HybridSearchCollector hybridSearchCollector, final boolean isSortEnabled) { List topDocs = hybridSearchCollector.topDocs(); - if (docValueFormats != null) { - newTopDocs = getNewTopFieldDocs( - getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), - topDocs, - sortAndFormats.sort.getSort() - ); - } else { - newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs); + if (isSortEnabled) { + return getSortedTopDocsAndMaxScore(topDocs, hybridSearchCollector); + } + return getTopDocsAndMaxScore(topDocs, hybridSearchCollector); + } + + private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(List topDocs, HybridSearchCollector hybridSearchCollector) { + TopDocs sortedTopDocs = getNewTopFieldDocs( + getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), + topDocs, + sortAndFormats.sort.getSort() + ); + return new TopDocsAndMaxScore(sortedTopDocs, hybridSearchCollector.getMaxScore()); + } + + private TopDocsAndMaxScore getTopDocsAndMaxScore(List topDocs, HybridSearchCollector hybridSearchCollector) { + List rescoredTopDocs = rescore(topDocs); + float maxScore = calculateMaxScore(rescoredTopDocs, hybridSearchCollector.getMaxScore()); + TopDocs finalTopDocs = getNewTopDocs( + getTotalHits(this.trackTotalHitsUpTo, rescoredTopDocs, hybridSearchCollector.getTotalHits()), + rescoredTopDocs + ); + return new TopDocsAndMaxScore(finalTopDocs, maxScore); + } + + private List rescore(List topDocs) { + List rescoreContexts = searchContext.rescore(); + boolean shouldRescore = Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty(); + if (!shouldRescore) { + return topDocs; + } + List rescoredTopDocs = topDocs; + for (RescoreContext ctx : rescoreContexts) { + rescoredTopDocs = rescoredTopDocs(ctx, rescoredTopDocs); } - return new TopDocsAndMaxScore(newTopDocs, hybridSearchCollector.getMaxScore()); + return rescoredTopDocs; + } + + /** + * Rescores the top documents using the provided context. The input topDocs may be modified during this process. + */ + private List rescoredTopDocs(final RescoreContext ctx, final List topDocs) { + List result = new ArrayList<>(topDocs.size()); + for (TopDocs topDoc : topDocs) { + try { + result.add(ctx.rescorer().rescore(topDoc, searchContext.searcher(), ctx)); + } catch (IOException exception) { + log.error("rescore failed for hybrid query", exception); + throw new OpenSearchException("rescore failed", exception); + } + } + return result; + } + + /** + * Calculates the maximum score from the provided TopDocs, considering rescoring. + */ + private float calculateMaxScore(List topDocsList, float initialMaxScore) { + List rescoreContexts = searchContext.rescore(); + if (Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty()) { + for (TopDocs topDocs : topDocsList) { + if (Objects.nonNull(topDocs.scoreDocs) && topDocs.scoreDocs.length > 0) { + // first top doc for each sub-query has the max score because top docs are sorted by score desc + initialMaxScore = Math.max(initialMaxScore, topDocs.scoreDocs[0].score); + } + } + } + return initialMaxScore; } private List getHybridSearchCollectors(final Collection collectors) { @@ -415,18 +473,18 @@ public HybridCollectorNonConcurrentManager( int numHits, HitsThresholdChecker hitsThresholdChecker, int trackTotalHitsUpTo, - SortAndFormats sortAndFormats, Weight filteringWeight, - ScoreDoc searchAfter + SearchContext searchContext ) { super( numHits, hitsThresholdChecker, trackTotalHitsUpTo, - sortAndFormats, + searchContext.sort(), filteringWeight, - new TopDocsMerger(sortAndFormats), - (FieldDoc) searchAfter + new TopDocsMerger(searchContext.sort()), + searchContext.searchAfter(), + searchContext ); scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); } @@ -453,18 +511,18 @@ public HybridCollectorConcurrentSearchManager( int numHits, HitsThresholdChecker hitsThresholdChecker, int trackTotalHitsUpTo, - SortAndFormats sortAndFormats, Weight filteringWeight, - ScoreDoc searchAfter + SearchContext searchContext ) { super( numHits, hitsThresholdChecker, trackTotalHitsUpTo, - sortAndFormats, + searchContext.sort(), filteringWeight, - new TopDocsMerger(sortAndFormats), - (FieldDoc) searchAfter + new TopDocsMerger(searchContext.sort()), + searchContext.searchAfter(), + searchContext ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 8c7390406..411127507 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -66,7 +66,9 @@ public boolean searchWith( } Query hybridQuery = extractHybridQuery(searchContext, query); QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext); - return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + // we decide on rescore later in collector manager + return false; } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index e6440cc61..b5e812780 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -13,6 +13,7 @@ import lombok.SneakyThrows; import org.junit.BeforeClass; import org.opensearch.client.ResponseException; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -467,6 +468,69 @@ public void testSearchAfter_whenAfterFieldIsNotPassed_thenFail() { } } + @SneakyThrows + public void testSortingWithRescoreWhenConcurrentSegmentSearchEnabledAndDisabled_whenBothSortAndRescorePresent_thenFail() { + try { + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY + ); + + Map fieldSortOrderMap = new HashMap<>(); + fieldSortOrderMap.put("stock", SortOrder.DESC); + + List searchAfter = new ArrayList<>(); + searchAfter.add(25); + + QueryBuilder rescoreQuery = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, TEXT_FIELD_VALUE_1_DUNES); + + assertThrows( + "Cannot use [sort] option in conjunction with [rescore].", + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + hybridQueryBuilder, + rescoreQuery, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + createSortBuilders(fieldSortOrderMap, false), + false, + searchAfter, + 0 + ) + ); + + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + + assertThrows( + "Cannot use [sort] option in conjunction with [rescore].", + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + hybridQueryBuilder, + rescoreQuery, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + createSortBuilders(fieldSortOrderMap, false), + false, + searchAfter, + 0 + ) + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + } + } + private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) { MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index de9c6006b..2426db4e9 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -12,16 +12,19 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TotalHits; @@ -54,11 +57,16 @@ import java.util.List; import java.util.Map; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; + +import org.opensearch.search.rescore.QueryRescorerBuilder; +import org.opensearch.search.rescore.RescoreContext; +import org.opensearch.search.rescore.RescorerBuilder; import org.opensearch.search.sort.SortAndFormats; public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { @@ -70,6 +78,7 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String QUERY1 = "hello"; private static final String QUERY2 = "hi"; private static final float DELTA_FOR_ASSERTION = 0.001f; + protected static final String QUERY3 = "everyone"; @SneakyThrows public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { @@ -734,4 +743,265 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD reader2.close(); directory2.close(); } + + @SneakyThrows + public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) + ) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(3); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(2); + IndexReaderContext indexReaderContext = mock(IndexReaderContext.class); + when(indexReader.getContext()).thenReturn(indexReaderContext); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + RescorerBuilder rescorerBuilder = new QueryRescorerBuilder(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2)); + RescoreContext rescoreContext = rescorerBuilder.buildContext(mockQueryShardContext); + List rescoreContexts = List.of(rescoreContext); + when(searchContext.rescore()).thenReturn(rescoreContexts); + when(indexReader.leaves()).thenReturn(reader.leaves()); + Weight rescoreWeight = mock(Weight.class); + Scorer rescoreScorer = mock(Scorer.class); + when(rescoreWeight.scorer(any())).thenReturn(rescoreScorer); + when(rescoreScorer.docID()).thenReturn(1); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(rescoreScorer.iterator()).thenReturn(iterator); + when(rescoreScorer.score()).thenReturn(0.9f); + when(indexSearcher.createWeight(any(), eq(ScoreMode.COMPLETE), eq(1f))).thenReturn(rescoreWeight); + + CollectorManager hybridCollectorManager1 = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager1.newCollector(); + + QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + + Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); + ParsedQuery parsedQuery = new ParsedQuery(pfQuery); + searchContext.parsedQuery(parsedQuery); + when(searchContext.parsedPostFilter()).thenReturn(parsedQuery); + when(indexSearcher.rewrite(pfQuery)).thenReturn(pfQuery); + Weight postFilterWeight = mock(Weight.class); + when(indexSearcher.createWeight(pfQuery, ScoreMode.COMPLETE_NO_SCORES, 1f)).thenReturn(postFilterWeight); + + CollectorManager hybridCollectorManager2 = HybridCollectorManager.createHybridCollectorManager(searchContext); + FilteredCollector filteredCollector = (FilteredCollector) hybridCollectorManager2.newCollector(); + + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector.setWeight(weight); + filteredCollector.setWeight(weight); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + LeafCollector filteredCollectorLeafCollector = filteredCollector.getLeafCollector(leafReaderContext); + BulkScorer scorer = weight.bulkScorer(leafReaderContext); + scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); + leafCollector.finish(); + scorer.score(filteredCollectorLeafCollector, leafReaderContext.reader().getLiveDocs()); + filteredCollectorLeafCollector.finish(); + + Object results1 = hybridCollectorManager1.reduce(List.of()); + Object results2 = hybridCollectorManager2.reduce(List.of()); + + assertNotNull(results1); + assertNotNull(results2); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results1); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(2, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(6, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertTrue(maxScore >= scoreDocs[2].score); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION); + assertEquals(maxScore, scoreDocs[4].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[5].score, DELTA_FOR_ASSERTION); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext) + ) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(2); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + // index segment 1 + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.flush(); + w.commit(); + + // index segment 2 + SearchContext searchContext2 = mock(SearchContext.class); + + ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class); + IndexReader indexReader2 = mock(IndexReader.class); + when(indexReader2.numDocs()).thenReturn(1); + when(indexSearcher2.getIndexReader()).thenReturn(indexReader); + when(searchContext2.searcher()).thenReturn(indexSearcher2); + when(searchContext2.size()).thenReturn(1); + + when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>()); + when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true); + + Directory directory2 = newDirectory(); + final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft2 = new FieldType(TextField.TYPE_NOT_STORED); + ft2.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft2.setOmitNorms(random().nextBoolean()); + ft2.freeze(); + + w2.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w2.flush(); + w2.commit(); + + IndexReader reader1 = DirectoryReader.open(w); + IndexSearcher searcher1 = newSearcher(reader1); + IndexReader reader2 = DirectoryReader.open(w2); + IndexSearcher searcher2 = newSearcher(reader2); + + List leafReaderContexts = reader1.leaves(); + IndexReaderContext indexReaderContext = mock(IndexReaderContext.class); + when(indexReader.getContext()).thenReturn(indexReaderContext); + when(indexReader.leaves()).thenReturn(leafReaderContexts); + // set up rescorer in a way that it boosts second documents from the first segment + RescorerBuilder rescorerBuilder = new QueryRescorerBuilder(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2)); + RescoreContext rescoreContext = rescorerBuilder.buildContext(mockQueryShardContext); + List rescoreContexts = List.of(rescoreContext); + when(searchContext.rescore()).thenReturn(rescoreContexts); + Weight rescoreWeight = mock(Weight.class); + Scorer rescoreScorer = mock(Scorer.class); + when(rescoreWeight.scorer(any())).thenReturn(rescoreScorer); + when(rescoreScorer.docID()).thenReturn(1); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(rescoreScorer.iterator()).thenReturn(iterator); + when(rescoreScorer.score()).thenReturn(0.9f); + when(indexSearcher.createWeight(any(), eq(ScoreMode.COMPLETE), eq(1f))).thenReturn(rescoreWeight); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector1 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + HybridTopScoreDocCollector collector2 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher1, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector1.setWeight(weight1); + collector2.setWeight(weight2); + + LeafReaderContext leafReaderContext = searcher1.getIndexReader().leaves().get(0); + LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); + BulkScorer scorer = weight1.bulkScorer(leafReaderContext); + scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); + + LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0); + LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2); + BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2); + scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs()); + leafCollector2.finish(); + + Object results = hybridCollectorManager.reduce(List.of(collector1, collector2)); + + // assert that second search hit in result has the max score due to boots from rescorer + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(3, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(8, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertTrue(maxScore > scoreDocs[2].score); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION); + assertEquals(maxScore, scoreDocs[4].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[5].score, DELTA_FOR_ASSERTION); + assertTrue(maxScore > scoreDocs[6].score); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[7].score, DELTA_FOR_ASSERTION); + + // release resources + w.close(); + reader1.close(); + directory.close(); + w2.close(); + reader2.close(); + directory2.close(); + } }