diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a72fcdaa..995159cc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Fixed merge logic for multiple collector result case ([#800](https://github.com/opensearch-project/neural-search/pull/800)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 456aa2def..b4767cc33 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.search.query; +import com.google.common.annotations.VisibleForTesting; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Collector; @@ -31,11 +32,14 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Comparator; import java.util.List; import java.util.Objects; +import static org.apache.lucene.search.TotalHits.Relation; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement; /** * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. @@ -44,9 +48,9 @@ @RequiredArgsConstructor public abstract class HybridCollectorManager implements CollectorManager { + private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3; private final int numHits; private final HitsThresholdChecker hitsThresholdChecker; - private final boolean isSingleShard; private final int trackTotalHitsUpTo; private final SortAndFormats sortAndFormats; @Nullable @@ -62,7 +66,6 @@ public abstract class HybridCollectorManager implements CollectorManager collectors) { } if (!hybridTopScoreDocCollectors.isEmpty()) { - HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream() - .findFirst() - .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); - List topDocs = hybridTopScoreDocCollector.topDocs(); - TopDocs newTopDocs = getNewTopDocs( - getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()), - topDocs - ); - TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore()); - return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; + List results = new ArrayList<>(); + for (HybridTopScoreDocCollector hybridTopScoreDocCollector : hybridTopScoreDocCollectors) { + List topDocs = hybridTopScoreDocCollector.topDocs(); + TopDocs newTopDocs = getNewTopDocs( + getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()), + topDocs + ); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore()); + + results.add((QuerySearchResult result) -> { + // this is case of first collector, query result object doesn't have any top docs set, so we can + // just set new top docs without merge + if (result.hasConsumedTopDocs()) { + result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); + return; + } + // in this case top docs are already present in result, and we need to merge next result object with what we have. + // if collector doesn't have any hits we can just skip it and save some cycles by not doing merge + if (newTopDocs.totalHits.value == 0) { + return; + } + // we need to do actual merge because query result and current collector both have some score hits + TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs(); + result.topDocs( + mergeTopDocsAndMaxScores(originalTotalDocsAndHits, topDocsAndMaxScore), + getSortValueFormats(sortAndFormats) + ); + }); + } + return reduceCollectorResults(results); } throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); } @@ -195,15 +216,10 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List top return new TopDocs(totalHits, scoreDocs); } - private TotalHits getTotalHits( - int trackTotalHitsUpTo, - final List topDocs, - final boolean isSingleShard, - final long maxTotalHits - ) { - final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final long maxTotalHits) { + final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? Relation.GREATER_THAN_OR_EQUAL_TO + : Relation.EQUAL_TO; if (topDocs == null || topDocs.isEmpty()) { return new TotalHits(0, relation); } @@ -215,6 +231,109 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats return sortAndFormats == null ? null : sortAndFormats.formats; } + private ReduceableSearchResult reduceCollectorResults(List results) { + return (result) -> { + for (ReduceableSearchResult r : results) { + r.reduce(result); + } + }; + } + + @VisibleForTesting + protected TopDocsAndMaxScore mergeTopDocsAndMaxScores(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { + return source; + } + // we need to merge hits per individual sub-query + // format of results in both new and source TopDocs is following + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + ScoreDoc[] sourceScoreDocs = source.topDocs.scoreDocs; + ScoreDoc[] newScoreDocs = newTopDocs.topDocs.scoreDocs; + + List mergedScoreDocs = mergedScoreDocs(sourceScoreDocs, newScoreDocs, Comparator.comparing((scoreDoc) -> scoreDoc.score)); + TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); + TopDocsAndMaxScore result = new TopDocsAndMaxScore( + new TopDocs(mergedTotalHits, mergedScoreDocs.toArray(new ScoreDoc[0])), + Math.max(source.maxScore, newTopDocs.maxScore) + ); + return result; + } + + /** + * Merge two score docs objects, result ScoreDocs[] object will have all hits per sub-query from both original objects. + * Logic is based on assumption that hits of every sub-query are sorted by score. + * Method returns new object and doesn't mutate original ScoreDocs arrays. + * @param sourceScoreDocs original score docs from query result + * @param newScoreDocs new score docs that we need to merge into existing scores + * @return merged array of ScoreDocs objects + */ + private List mergedScoreDocs( + final ScoreDoc[] sourceScoreDocs, + final ScoreDoc[] newScoreDocs, + final Comparator scoreDocComparator + ) { + if (Objects.requireNonNull(sourceScoreDocs).length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC + || Objects.requireNonNull(newScoreDocs).length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) { + throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements"); + } + // we overshoot and preallocate more than we need - length of both top docs combined. + // we will take only portion of the array at the end + List mergedScoreDocs = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length); + int sourcePointer = 0; + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + // new pointer is set to 1 as we don't care about it start-stop element + int newPointer = 1; + + while (sourcePointer < sourceScoreDocs.length - 1 && newPointer < newScoreDocs.length - 1) { + // every iteration is for results of one sub-query + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + newPointer++; + // simplest case when both arrays have results for sub-query + while (sourcePointer < sourceScoreDocs.length + && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer]) + && newPointer < newScoreDocs.length + && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { + if (scoreDocComparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) { + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + } else { + mergedScoreDocs.add(newScoreDocs[newPointer]); + newPointer++; + } + } + // at least one object got exhausted at this point, now merge all elements from object that's left + while (sourcePointer < sourceScoreDocs.length && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])) { + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + } + while (newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { + mergedScoreDocs.add(newScoreDocs[newPointer]); + newPointer++; + } + } + mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]); + return mergedScoreDocs; + } + + private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + // merged value is a lower bound - if both are equal_to than merged will also be equal_to, + // otherwise assign greater_than_or_equal + Relation mergedHitsRelation = source.topDocs.totalHits.relation == Relation.GREATER_THAN_OR_EQUAL_TO + || newTopDocs.topDocs.totalHits.relation == Relation.GREATER_THAN_OR_EQUAL_TO + ? Relation.GREATER_THAN_OR_EQUAL_TO + : Relation.EQUAL_TO; + return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation); + } + /** * Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to * use saved state of collector @@ -225,12 +344,11 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager public HybridCollectorNonConcurrentManager( int numHits, HitsThresholdChecker hitsThresholdChecker, - boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight ) { - super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight); + super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight); scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); } @@ -255,12 +373,11 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag public HybridCollectorConcurrentSearchManager( int numHits, HitsThresholdChecker hitsThresholdChecker, - boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight ) { - super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight); + super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight); } } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java index 162647b11..8fc71056a 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java @@ -52,4 +52,28 @@ public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) { public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) { return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0; } + + /** + * Checking if passed scoreDocs object is a special element (start/stop or delimiter) in the list of hybrid query result scores + * @param scoreDoc score doc object to check on + * @return true if it is a special element + */ + public static boolean isHybridQuerySpecialElement(final ScoreDoc scoreDoc) { + if (Objects.isNull(scoreDoc)) { + return false; + } + return isHybridQueryStartStopElement(scoreDoc) || isHybridQueryDelimiterElement(scoreDoc); + } + + /** + * Checking if passed scoreDocs object is a document score element + * @param scoreDoc score doc object to check on + * @return true if element has score + */ + public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) { + if (Objects.isNull(scoreDoc)) { + return false; + } + return !isHybridQuerySpecialElement(scoreDoc); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java index 9e72dfcb1..4bc40add8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -100,43 +100,43 @@ protected boolean preserveClusterUponCompletion() { @SneakyThrows public void testPipelineAggs_whenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAvgSumMinMaxAggs(); } @SneakyThrows public void testPipelineAggs_whenConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testAvgSumMinMaxAggs(); } @SneakyThrows public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testMaxAggsOnSingleShardCluster(); } @SneakyThrows public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testMaxAggsOnSingleShardCluster(); } @SneakyThrows public void testBucketAndNestedAggs_whenConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateRange(); } @SneakyThrows public void testBucketAndNestedAggs_whenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateRange(); } @SneakyThrows public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); try { prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); @@ -177,14 +177,14 @@ public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSu @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPostFilterWithSimpleHybridQuery(false, true); testPostFilterWithComplexHybridQuery(false, true); } @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPostFilterWithSimpleHybridQuery(false, true); testPostFilterWithComplexHybridQuery(false, true); } @@ -420,14 +420,14 @@ private void testAvgSumMinMaxAggs() { @SneakyThrows public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPostFilterWithSimpleHybridQuery(true, true); testPostFilterWithComplexHybridQuery(true, true); } @SneakyThrows public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPostFilterWithSimpleHybridQuery(true, true); testPostFilterWithComplexHybridQuery(true, true); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 43e302698..a650087b4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -22,6 +22,8 @@ import java.util.Set; import java.util.stream.IntStream; +import org.apache.commons.lang.RandomStringUtils; +import org.apache.commons.lang.math.RandomUtils; import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import org.opensearch.client.ResponseException; @@ -47,6 +49,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = "test-hybrid-multi-doc-nested-type-single-shard-index"; private static final String TEST_INDEX_WITH_KEYWORDS_ONE_SHARD = "test-hybrid-keywords-single-shard-index"; + private static final String TEST_INDEX_DOC_QTY_ONE_SHARD = "test-hybrid-doc-qty-single-shard-index"; + private static final String TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS = "test-hybrid-doc-qty-multiple-shards-index"; private static final String TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS = "test-hybrid-keywords-three-shards-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; @@ -76,6 +80,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final int INTEGER_FIELD_PRICE_4_VALUE = 25; private static final int INTEGER_FIELD_PRICE_5_VALUE = 30; private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; + protected static final int SINGLE_SHARD = 1; + protected static final int MULTIPLE_SHARDS = 3; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -692,6 +698,101 @@ public void testWrappedQueryWithFilter_whenIndexAliasHasFilters_thenSuccess() { } } + @SneakyThrows + public void testConcurrentSearchWithMultipleSlices_whenSingleShardIndex_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + int numberOfDocumentsInIndex = 1_000; + initializeIndexIfNotExist(TEST_INDEX_DOC_QTY_ONE_SHARD, SINGLE_SHARD, numberOfDocumentsInIndex); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.matchAllQuery()); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_DOC_QTY_ONE_SHARD, + hybridQueryBuilder, + null, + numberOfDocumentsInIndex, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + int queryHitCount = getHitCount(firstSearchResponseAsMap); + assertEquals(numberOfDocumentsInIndex, queryHitCount); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(numberOfDocumentsInIndex, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_DOC_QTY_ONE_SHARD, null, null, SEARCH_PIPELINE); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + } + } + + @SneakyThrows + public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + int numberOfDocumentsInIndex = 2_000; + initializeIndexIfNotExist(TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, MULTIPLE_SHARDS, numberOfDocumentsInIndex); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.matchAllQuery()); + hybridQueryBuilder.add(QueryBuilders.rangeQuery(INTEGER_FIELD_PRICE).gte(0).lte(1000)); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, + hybridQueryBuilder, + null, + numberOfDocumentsInIndex, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + int queryHitCount = getHitCount(firstSearchResponseAsMap); + assertEquals(numberOfDocumentsInIndex, queryHitCount); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(numberOfDocumentsInIndex, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -784,7 +885,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { buildIndexConfiguration( Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), List.of(TEST_NESTED_TYPE_FIELD_NAME_1), - 1 + SINGLE_SHARD ), "" ); @@ -805,7 +906,14 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_INDEX_WITH_KEYWORDS_ONE_SHARD.equals(indexName) && !indexExists(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD)) { createIndexWithConfiguration( indexName, - buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_PRICE), List.of(KEYWORD_FIELD_1), List.of(), 1), + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(KEYWORD_FIELD_1), + List.of(), + SINGLE_SHARD + ), "" ); addDocWithKeywordsAndIntFields( @@ -901,6 +1009,34 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { } } + @SneakyThrows + private void initializeIndexIfNotExist(String indexName, int numberOfShards, int numberOfDocuments) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(KEYWORD_FIELD_1), + List.of(), + numberOfShards + ), + "" + ); + for (int i = 0; i < numberOfDocuments; i++) { + addDocWithKeywordsAndIntFields( + indexName, + String.valueOf(i), + INTEGER_FIELD_PRICE, + RandomUtils.nextInt(1000), + KEYWORD_FIELD_1, + RandomStringUtils.randomAlphabetic(10) + ); + } + } + } + private void addDocsToIndex(final String testMultiDocIndexName) { addKnnDoc( testMultiDocIndexName, diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 1fd67a7ae..daae2b4fa 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -20,6 +20,7 @@ import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.search.Query; @@ -50,6 +51,8 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; @@ -60,6 +63,7 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT2 = "Hi to this place"; private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String QUERY1 = "hello"; + private static final String QUERY2 = "hi"; private static final float DELTA_FOR_ASSERTION = 0.001f; @SneakyThrows @@ -309,4 +313,459 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { reader.close(); directory.close(); } + + @SneakyThrows + public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedDocs_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) + ) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(2); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + SearchContext searchContext2 = mock(SearchContext.class); + + ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class); + IndexReader indexReader2 = mock(IndexReader.class); + when(indexReader2.numDocs()).thenReturn(1); + when(indexSearcher2.getIndexReader()).thenReturn(indexReader); + when(searchContext2.searcher()).thenReturn(indexSearcher2); + when(searchContext2.size()).thenReturn(1); + + when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>()); + when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory2 = newDirectory(); + final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft2 = new FieldType(TextField.TYPE_NOT_STORED); + ft2.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft2.setOmitNorms(random().nextBoolean()); + ft2.freeze(); + + w2.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w2.flush(); + w2.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + IndexReader reader2 = DirectoryReader.open(w2); + IndexSearcher searcher2 = newSearcher(reader2); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector1 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + HybridTopScoreDocCollector collector2 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector1.setWeight(weight1); + collector2.setWeight(weight2); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); + + LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0); + LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2); + BulkScorer scorer = weight1.bulkScorer(leafReaderContext); + scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); + BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2); + scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs()); + leafCollector2.finish(); + + Object results = hybridCollectorManager.reduce(List.of(collector1, collector2)); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(2, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(6, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertTrue(scoreDocs[2].score > 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION); + assertTrue(scoreDocs[4].score > 0); + + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[5].score, DELTA_FOR_ASSERTION); + // we have to assert that one of hits is max score because scores are generated for each run and order is not guaranteed + assertTrue(Float.compare(scoreDocs[2].score, maxScore) == 0 || Float.compare(scoreDocs[4].score, maxScore) == 0); + + w.close(); + reader.close(); + directory.close(); + w2.close(); + reader2.close(); + directory2.close(); + } + + @SneakyThrows + public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(3); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + HybridCollectorManager hybridCollectorManager = (HybridCollectorManager) HybridCollectorManager.createHybridCollectorManager( + searchContext + ); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = hybridCollectorManager.mergeTopDocsAndMaxScores( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew + ); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(6, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 5 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 5 + 1 + 2 + 2 = 10 + assertEquals(10, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 0, 0.5f); + assertScoreDoc(scoreDocs[4], 2, 0.3f); + assertScoreDoc(scoreDocs[5], 4, 0.3f); + assertScoreDoc(scoreDocs[6], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[7].score, 0); + assertScoreDoc(scoreDocs[8], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[9].score, 0); + } + + @SneakyThrows + public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(3); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + HybridCollectorManager hybridCollectorManager = (HybridCollectorManager) HybridCollectorManager.createHybridCollectorManager( + searchContext + ); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = hybridCollectorManager.mergeTopDocsAndMaxScores( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew + ); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 3 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 3 + 1 + 2 + 2 = 8 + assertEquals(8, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 4, 0.3f); + assertScoreDoc(scoreDocs[4], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[5].score, 0); + assertScoreDoc(scoreDocs[6], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[7].score, 0); + } + + @SneakyThrows + public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(3); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + HybridCollectorManager hybridCollectorManager = (HybridCollectorManager) HybridCollectorManager.createHybridCollectorManager( + searchContext + ); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0); + TopDocs topDocsNew = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = hybridCollectorManager.mergeTopDocsAndMaxScores( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew + ); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(0, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[2].score, 0); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, 0); + } + + @SneakyThrows + public void testThreeSequentialMerges_whenAllTopDocsHasHits_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(3); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + HybridCollectorManager hybridCollectorManager = (HybridCollectorManager) HybridCollectorManager.createHybridCollectorManager( + searchContext + ); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore firstMergedTopDocsAndMaxScore = hybridCollectorManager.mergeTopDocsAndMaxScores( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew + ); + + assertNotNull(firstMergedTopDocsAndMaxScore); + + // merge results from collector 3 + TopDocs topDocsThirdCollector = new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(3), + createDelimiterElementForHybridSearchResults(3), + new ScoreDoc(3, 0.4f), + createDelimiterElementForHybridSearchResults(3), + new ScoreDoc(7, 0.85f), + new ScoreDoc(9, 0.2f), + createStartStopElementForHybridSearchResults(3) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreThirdCollector = new TopDocsAndMaxScore(topDocsThirdCollector, 0.85f); + TopDocsAndMaxScore finalMergedTopDocsAndMaxScore = hybridCollectorManager.mergeTopDocsAndMaxScores( + firstMergedTopDocsAndMaxScore, + topDocsAndMaxScoreThirdCollector + ); + + assertEquals(0.85f, finalMergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(9, finalMergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, finalMergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 6 from sub-query1 and 3 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 6 + 3 + 2 + 2 = 13 + assertEquals(13, finalMergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = finalMergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 0, 0.5f); + assertScoreDoc(scoreDocs[4], 3, 0.4f); + assertScoreDoc(scoreDocs[5], 2, 0.3f); + assertScoreDoc(scoreDocs[6], 4, 0.3f); + assertScoreDoc(scoreDocs[7], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[8].score, 0); + assertScoreDoc(scoreDocs[9], 7, 0.85f); + assertScoreDoc(scoreDocs[10], 4, 0.6f); + assertScoreDoc(scoreDocs[11], 9, 0.2f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[12].score, 0); + } + + private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) { + assertEquals(expectedDocId, scoreDoc.doc); + assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java index 16f2f10ce..d84e196cd 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java @@ -4,16 +4,14 @@ */ package org.opensearch.neuralsearch.search.util; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQuerySpecialElement; import org.apache.lucene.search.ScoreDoc; import org.opensearch.common.Randomness; @@ -57,4 +55,23 @@ public void testCreateElements_whenCreateStartStopAndDelimiterElements_thenSucce assertEquals(docId, delimiterElement.doc); assertEquals(MAGIC_NUMBER_DELIMITER, delimiterElement.score, 0.0f); } + + public void testSpecialElementCheck_whenElementIsSpecialAndIsNotSpecial_thenSuccessful() { + int docId = 1; + ScoreDoc startStopElement = new ScoreDoc(docId, MAGIC_NUMBER_START_STOP); + assertTrue(isHybridQuerySpecialElement(startStopElement)); + assertFalse(isHybridQueryScoreDocElement(startStopElement)); + + ScoreDoc delimiterElement = new ScoreDoc(docId, MAGIC_NUMBER_DELIMITER); + assertTrue(isHybridQuerySpecialElement(delimiterElement)); + assertFalse(isHybridQueryScoreDocElement(delimiterElement)); + } + + public void testScoreElementCheck_whenElementIsSpecialAndIsNotSpecial_thenSuccessful() { + int docId = 1; + float score = Randomness.get().nextFloat(); + ScoreDoc startStopElement = new ScoreDoc(docId, score); + assertFalse(isHybridQuerySpecialElement(startStopElement)); + assertTrue(isHybridQueryScoreDocElement(startStopElement)); + } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 2682ee7c7..689e4bf98 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -83,6 +83,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); + protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; protected final ClassLoader classLoader = this.getClass().getClassLoader();