From 57cf215e08f8d4a7b114d3c935f2cef2103cbf22 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 25 Jun 2024 13:43:10 -0700 Subject: [PATCH] Fix for missing HybridQuery results when concurrent segment search is enabled (#800) (#804) * Adding merge logic for multiple collector result case Signed-off-by: Martin Gaievski (cherry picked from commit 25d2e82b62b52eae4d73f8f33aab704ff01c0390) Co-authored-by: Martin Gaievski --- CHANGELOG.md | 1 + .../search/query/HybridCollectorManager.java | 101 ++++--- .../query/HybridQueryScoreDocsMerger.java | 83 ++++++ .../search/query/TopDocsMerger.java | 74 +++++ .../util/HybridSearchResultFormatUtil.java | 24 ++ .../query/HybridQueryAggregationsIT.java | 22 +- .../neuralsearch/query/HybridQueryIT.java | 140 +++++++++- .../query/HybridQueryPostFilterIT.java | 8 +- .../BaseAggregationsWithHybridQueryIT.java | 1 - .../BucketAggregationsWithHybridQueryIT.java | 54 ++-- .../MetricAggregationsWithHybridQueryIT.java | 34 +-- ...PipelineAggregationsWithHybridQueryIT.java | 16 +- .../query/HybridCollectorManagerTests.java | 127 +++++++++ .../HybridQueryScoreDocsMergerTests.java | 154 +++++++++++ .../search/query/TopDocsMergerTests.java | 255 ++++++++++++++++++ .../HybridSearchResultFormatUtilTests.java | 25 +- .../neuralsearch/BaseNeuralSearchIT.java | 1 + 17 files changed, 1015 insertions(+), 105 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a72fcdaa..4b6df7e66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 456aa2def..08e7bf657 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -34,6 +34,8 @@ import java.util.List; import java.util.Objects; +import static org.apache.lucene.search.TotalHits.Relation; +import static org.opensearch.neuralsearch.search.query.TopDocsMerger.TOP_DOCS_MERGER_TOP_SCORES; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; @@ -46,12 +48,12 @@ public abstract class HybridCollectorManager implements CollectorManager collectors) { + final List hybridTopScoreDocCollectors = getHybridScoreDocCollectors(collectors); + if (hybridTopScoreDocCollectors.isEmpty()) { + throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); + } + + List results = new ArrayList<>(); + DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats); + for (HybridTopScoreDocCollector hybridTopScoreDocCollector : hybridTopScoreDocCollectors) { + List topDocs = hybridTopScoreDocCollector.topDocs(); + TopDocs newTopDocs = getNewTopDocs( + getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()), + topDocs + ); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore()); + + results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats, newTopDocs)); + } + return reduceSearchResults(results); + } + + private List getHybridScoreDocCollectors(Collection collectors) { final List hybridTopScoreDocCollectors = new ArrayList<>(); // check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper // in case multiple collector managers are registered. We use hybrid scores collector to format scores into @@ -136,20 +156,7 @@ public ReduceableSearchResult reduce(Collection collectors) { hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector()); } } - - if (!hybridTopScoreDocCollectors.isEmpty()) { - HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream() - .findFirst() - .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); - List topDocs = hybridTopScoreDocCollector.topDocs(); - TopDocs newTopDocs = getNewTopDocs( - getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()), - topDocs - ); - TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore()); - return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; - } - throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); + return hybridTopScoreDocCollectors; } private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { @@ -195,15 +202,10 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List top return new TopDocs(totalHits, scoreDocs); } - private TotalHits getTotalHits( - int trackTotalHitsUpTo, - final List topDocs, - final boolean isSingleShard, - final long maxTotalHits - ) { - final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final long maxTotalHits) { + final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? Relation.GREATER_THAN_OR_EQUAL_TO + : Relation.EQUAL_TO; if (topDocs == null || topDocs.isEmpty()) { return new TotalHits(0, relation); } @@ -215,6 +217,45 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats return sortAndFormats == null ? null : sortAndFormats.formats; } + private void reduceCollectorResults( + QuerySearchResult result, + TopDocsAndMaxScore topDocsAndMaxScore, + DocValueFormat[] docValueFormats, + TopDocs newTopDocs + ) { + // this is case of first collector, query result object doesn't have any top docs set, so we can + // just set new top docs without merge + // this call is effectively checking if QuerySearchResult.topDoc is null. using it in such way because + // getter throws exception in case topDocs is null + if (result.hasConsumedTopDocs()) { + result.topDocs(topDocsAndMaxScore, docValueFormats); + return; + } + // in this case top docs are already present in result, and we need to merge next result object with what we have. + // if collector doesn't have any hits we can just skip it and save some cycles by not doing merge + if (newTopDocs.totalHits.value == 0) { + return; + } + // we need to do actual merge because query result and current collector both have some score hits + TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs(); + TopDocsAndMaxScore mergeTopDocsAndMaxScores = topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore); + result.topDocs(mergeTopDocsAndMaxScores, docValueFormats); + } + + /** + * For collection of search results, return a single one that has results from all individual result objects. + * @param results collection of search results + * @return single search result that represents all results as one object + */ + private ReduceableSearchResult reduceSearchResults(List results) { + return (result) -> { + for (ReduceableSearchResult r : results) { + // call reduce for results of each single collector, this will update top docs in query result + r.reduce(result); + } + }; + } + /** * Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to * use saved state of collector @@ -225,12 +266,11 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager public HybridCollectorNonConcurrentManager( int numHits, HitsThresholdChecker hitsThresholdChecker, - boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight ) { - super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight); + super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES); scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); } @@ -255,12 +295,11 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag public HybridCollectorConcurrentSearchManager( int numHits, HitsThresholdChecker hitsThresholdChecker, - boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight ) { - super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight); + super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES); } } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java new file mode 100644 index 000000000..7eb6e2b55 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.lucene.search.ScoreDoc; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement; + +/** + * Merges two ScoreDoc arrays into one + */ +@NoArgsConstructor(access = AccessLevel.PACKAGE) +class HybridQueryScoreDocsMerger { + + private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3; + + /** + * Merge two score docs objects, result ScoreDocs[] object will have all hits per sub-query from both original objects. + * Input and output ScoreDocs are in format that is specific to Hybrid Query. This method should not be used for ScoreDocs from + * other query types. + * Logic is based on assumption that hits of every sub-query are sorted by score. + * Method returns new object and doesn't mutate original ScoreDocs arrays. + * @param sourceScoreDocs original score docs from query result + * @param newScoreDocs new score docs that we need to merge into existing scores + * @return merged array of ScoreDocs objects + */ + public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator comparator) { + if (Objects.requireNonNull(sourceScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC + || Objects.requireNonNull(newScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) { + throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements"); + } + // we overshoot and preallocate more than we need - length of both top docs combined. + // we will take only portion of the array at the end + List mergedScoreDocs = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length); + int sourcePointer = 0; + // mark beginning of hybrid query results by start element + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + // new pointer is set to 1 as we don't care about it start-stop element + int newPointer = 1; + + while (sourcePointer < sourceScoreDocs.length - 1 && newPointer < newScoreDocs.length - 1) { + // every iteration is for results of one sub-query + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + newPointer++; + // simplest case when both arrays have results for sub-query + while (sourcePointer < sourceScoreDocs.length + && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer]) + && newPointer < newScoreDocs.length + && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { + if (comparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) { + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + } else { + mergedScoreDocs.add(newScoreDocs[newPointer]); + newPointer++; + } + } + // at least one object got exhausted at this point, now merge all elements from object that's left + while (sourcePointer < sourceScoreDocs.length && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])) { + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + } + while (newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { + mergedScoreDocs.add(newScoreDocs[newPointer]); + newPointer++; + } + } + // mark end of hybrid query results by end element + mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]); + return mergedScoreDocs.toArray((T[]) new ScoreDoc[0]); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java new file mode 100644 index 000000000..0e6adfb1a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; + +import java.util.Comparator; +import java.util.Objects; + +/** + * Utility class for merging TopDocs and MaxScore across multiple search queries + */ +@RequiredArgsConstructor(access = AccessLevel.PACKAGE) +class TopDocsMerger { + + private final HybridQueryScoreDocsMerger scoreDocsMerger; + @VisibleForTesting + protected static final Comparator SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score); + /** + * Uses hybrid query score docs merger to merge internal score docs + */ + static final TopDocsMerger TOP_DOCS_MERGER_TOP_SCORES = new TopDocsMerger(new HybridQueryScoreDocsMerger<>()); + + /** + * Merge TopDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object. + * @param source TopDocsAndMaxScore for the original query + * @param newTopDocs TopDocsAndMaxScore for the new query + * @return merged TopDocsAndMaxScore object + */ + public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { + return source; + } + // we need to merge hits per individual sub-query + // format of results in both new and source TopDocs is following + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + source.topDocs.scoreDocs, + newTopDocs.topDocs.scoreDocs, + SCORE_DOC_BY_SCORE_COMPARATOR + ); + TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); + TopDocsAndMaxScore result = new TopDocsAndMaxScore( + new TopDocs(mergedTotalHits, mergedScoreDocs), + Math.max(source.maxScore, newTopDocs.maxScore) + ); + return result; + } + + private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + // merged value is a lower bound - if both are equal_to than merged will also be equal_to, + // otherwise assign greater_than_or_equal + TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + || newTopDocs.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java index 162647b11..8fc71056a 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java @@ -52,4 +52,28 @@ public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) { public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) { return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0; } + + /** + * Checking if passed scoreDocs object is a special element (start/stop or delimiter) in the list of hybrid query result scores + * @param scoreDoc score doc object to check on + * @return true if it is a special element + */ + public static boolean isHybridQuerySpecialElement(final ScoreDoc scoreDoc) { + if (Objects.isNull(scoreDoc)) { + return false; + } + return isHybridQueryStartStopElement(scoreDoc) || isHybridQueryDelimiterElement(scoreDoc); + } + + /** + * Checking if passed scoreDocs object is a document score element + * @param scoreDoc score doc object to check on + * @return true if element has score + */ + public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) { + if (Objects.isNull(scoreDoc)) { + return false; + } + return !isHybridQuerySpecialElement(scoreDoc); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java index 9e72dfcb1..4bc40add8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -100,43 +100,43 @@ protected boolean preserveClusterUponCompletion() { @SneakyThrows public void testPipelineAggs_whenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAvgSumMinMaxAggs(); } @SneakyThrows public void testPipelineAggs_whenConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testAvgSumMinMaxAggs(); } @SneakyThrows public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testMaxAggsOnSingleShardCluster(); } @SneakyThrows public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testMaxAggsOnSingleShardCluster(); } @SneakyThrows public void testBucketAndNestedAggs_whenConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateRange(); } @SneakyThrows public void testBucketAndNestedAggs_whenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateRange(); } @SneakyThrows public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); try { prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); @@ -177,14 +177,14 @@ public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSu @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPostFilterWithSimpleHybridQuery(false, true); testPostFilterWithComplexHybridQuery(false, true); } @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPostFilterWithSimpleHybridQuery(false, true); testPostFilterWithComplexHybridQuery(false, true); } @@ -420,14 +420,14 @@ private void testAvgSumMinMaxAggs() { @SneakyThrows public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPostFilterWithSimpleHybridQuery(true, true); testPostFilterWithComplexHybridQuery(true, true); } @SneakyThrows public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPostFilterWithSimpleHybridQuery(true, true); testPostFilterWithComplexHybridQuery(true, true); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 43e302698..a650087b4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -22,6 +22,8 @@ import java.util.Set; import java.util.stream.IntStream; +import org.apache.commons.lang.RandomStringUtils; +import org.apache.commons.lang.math.RandomUtils; import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import org.opensearch.client.ResponseException; @@ -47,6 +49,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = "test-hybrid-multi-doc-nested-type-single-shard-index"; private static final String TEST_INDEX_WITH_KEYWORDS_ONE_SHARD = "test-hybrid-keywords-single-shard-index"; + private static final String TEST_INDEX_DOC_QTY_ONE_SHARD = "test-hybrid-doc-qty-single-shard-index"; + private static final String TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS = "test-hybrid-doc-qty-multiple-shards-index"; private static final String TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS = "test-hybrid-keywords-three-shards-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; @@ -76,6 +80,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final int INTEGER_FIELD_PRICE_4_VALUE = 25; private static final int INTEGER_FIELD_PRICE_5_VALUE = 30; private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; + protected static final int SINGLE_SHARD = 1; + protected static final int MULTIPLE_SHARDS = 3; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -692,6 +698,101 @@ public void testWrappedQueryWithFilter_whenIndexAliasHasFilters_thenSuccess() { } } + @SneakyThrows + public void testConcurrentSearchWithMultipleSlices_whenSingleShardIndex_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + int numberOfDocumentsInIndex = 1_000; + initializeIndexIfNotExist(TEST_INDEX_DOC_QTY_ONE_SHARD, SINGLE_SHARD, numberOfDocumentsInIndex); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.matchAllQuery()); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_DOC_QTY_ONE_SHARD, + hybridQueryBuilder, + null, + numberOfDocumentsInIndex, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + int queryHitCount = getHitCount(firstSearchResponseAsMap); + assertEquals(numberOfDocumentsInIndex, queryHitCount); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(numberOfDocumentsInIndex, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_DOC_QTY_ONE_SHARD, null, null, SEARCH_PIPELINE); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + } + } + + @SneakyThrows + public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + int numberOfDocumentsInIndex = 2_000; + initializeIndexIfNotExist(TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, MULTIPLE_SHARDS, numberOfDocumentsInIndex); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.matchAllQuery()); + hybridQueryBuilder.add(QueryBuilders.rangeQuery(INTEGER_FIELD_PRICE).gte(0).lte(1000)); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, + hybridQueryBuilder, + null, + numberOfDocumentsInIndex, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + int queryHitCount = getHitCount(firstSearchResponseAsMap); + assertEquals(numberOfDocumentsInIndex, queryHitCount); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(numberOfDocumentsInIndex, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -784,7 +885,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { buildIndexConfiguration( Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), List.of(TEST_NESTED_TYPE_FIELD_NAME_1), - 1 + SINGLE_SHARD ), "" ); @@ -805,7 +906,14 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_INDEX_WITH_KEYWORDS_ONE_SHARD.equals(indexName) && !indexExists(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD)) { createIndexWithConfiguration( indexName, - buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_PRICE), List.of(KEYWORD_FIELD_1), List.of(), 1), + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(KEYWORD_FIELD_1), + List.of(), + SINGLE_SHARD + ), "" ); addDocWithKeywordsAndIntFields( @@ -901,6 +1009,34 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { } } + @SneakyThrows + private void initializeIndexIfNotExist(String indexName, int numberOfShards, int numberOfDocuments) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(KEYWORD_FIELD_1), + List.of(), + numberOfShards + ), + "" + ); + for (int i = 0; i < numberOfDocuments; i++) { + addDocWithKeywordsAndIntFields( + indexName, + String.valueOf(i), + INTEGER_FIELD_PRICE, + RandomUtils.nextInt(1000), + KEYWORD_FIELD_1, + RandomStringUtils.randomAlphabetic(10) + ); + } + } + } + private void addDocsToIndex(final String testMultiDocIndexName) { addKnnDoc( testMultiDocIndexName, diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index 7d33d07fe..8f8ae8cc4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -68,7 +68,7 @@ public static void setUpCluster() { @SneakyThrows public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); @@ -81,7 +81,7 @@ public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_the @SneakyThrows public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); @@ -94,7 +94,7 @@ public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_th @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); @@ -107,7 +107,7 @@ public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_ @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchDisabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java index 48fb8f8d6..5cc5b9170 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java @@ -77,7 +77,6 @@ public class BaseAggregationsWithHybridQueryIT extends BaseNeuralSearchIT { protected static final String AVG_AGGREGATION_NAME = "avg_field"; protected static final String GENERIC_AGGREGATION_NAME = "my_aggregation"; protected static final String DATE_AGGREGATION_NAME = "date_aggregation"; - protected static final String CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH = "search.concurrent_segment_search.enabled"; @BeforeClass @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java index ce8854eed..7385f48e5 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java @@ -68,165 +68,165 @@ public class BucketAggregationsWithHybridQueryIT extends BaseAggregationsWithHyb @SneakyThrows public void testBucketAndNestedAggs_whenAdjacencyMatrix_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testAdjacencyMatrixAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenAdjacencyMatrix_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAdjacencyMatrixAggs(); } @SneakyThrows public void testBucketAndNestedAggs_whenDiversifiedSampler_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDiversifiedSampler(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDiversifiedSampler_thenFail() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDiversifiedSampler(); } @SneakyThrows public void testBucketAndNestedAggs_whenAvgNestedIntoFilter_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testAvgNestedIntoFilter(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenAvgNestedIntoFilter_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAvgNestedIntoFilter(); } @SneakyThrows public void testBucketAndNestedAggs_whenSumNestedIntoFilters_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSumNestedIntoFilters(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSumNestedIntoFilters_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSumNestedIntoFilters(); } @SneakyThrows public void testBucketAggs_whenGlobalAggUsedWithQuery_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testGlobalAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenGlobalAggUsedWithQuery_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testGlobalAggs(); } @SneakyThrows public void testBucketAggs_whenHistogramAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testHistogramAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenHistogramAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testHistogramAggs(); } @SneakyThrows public void testBucketAggs_whenNestedAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testNestedAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenNestedAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testNestedAggs(); } @SneakyThrows public void testBucketAggs_whenSamplerAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSampler(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSamplerAgg_thenFail() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSampler(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketScriptAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testMetricAggs_whenTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testTermsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testTermsAggs(); } @SneakyThrows public void testMetricAggs_whenSignificantTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSignificantTermsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSignificantTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSignificantTermsAggs(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java index 844963790..94f1e7207 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java @@ -81,103 +81,103 @@ public class MetricAggregationsWithHybridQueryIT extends BaseAggregationsWithHyb */ @SneakyThrows public void testWithConcurrentSegmentSearch_whenAvgAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAvgAggs(); } @SneakyThrows public void testMetricAggs_whenCardinalityAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testCardinalityAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenCardinalityAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testCardinalityAggs(); } @SneakyThrows public void testMetricAggs_whenExtendedStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testExtendedStatsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenExtendedStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testExtendedStatsAggs(); } @SneakyThrows public void testMetricAggs_whenTopHitsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testTopHitsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenTopHitsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testTopHitsAggs(); } @SneakyThrows public void testMetricAggs_whenPercentileRank_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPercentileRankAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenPercentileRank_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPercentileRankAggs(); } @SneakyThrows public void testMetricAggs_whenPercentile_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPercentileAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenPercentile_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPercentileAggs(); } @SneakyThrows public void testMetricAggs_whenScriptedMetrics_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testScriptedMetricsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenScriptedMetrics_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testScriptedMetricsAggs(); } @SneakyThrows public void testMetricAggs_whenSumAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSumAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSumAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSumAggs(); } @SneakyThrows public void testMetricAggs_whenValueCount_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testValueCountAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenValueCount_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testValueCountAggs(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java index 168dce1e0..fd118629b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java @@ -53,49 +53,49 @@ public class PipelineAggregationsWithHybridQueryIT extends BaseAggregationsWithH @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketSortAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketSortAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketSortAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketSortAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToCumulativeSumAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToCumulativeSumAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToCumulativeSumAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToCumulativeSumAggs(); } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 1fd67a7ae..40d2ee3f6 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -60,6 +60,7 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT2 = "Hi to this place"; private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String QUERY1 = "hello"; + private static final String QUERY2 = "hi"; private static final float DELTA_FOR_ASSERTION = 0.001f; @SneakyThrows @@ -309,4 +310,130 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { reader.close(); directory.close(); } + + @SneakyThrows + public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedDocs_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) + ) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(2); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + SearchContext searchContext2 = mock(SearchContext.class); + + ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class); + IndexReader indexReader2 = mock(IndexReader.class); + when(indexReader2.numDocs()).thenReturn(1); + when(indexSearcher2.getIndexReader()).thenReturn(indexReader); + when(searchContext2.searcher()).thenReturn(indexSearcher2); + when(searchContext2.size()).thenReturn(1); + + when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>()); + when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory2 = newDirectory(); + final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft2 = new FieldType(TextField.TYPE_NOT_STORED); + ft2.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft2.setOmitNorms(random().nextBoolean()); + ft2.freeze(); + + w2.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w2.flush(); + w2.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + IndexReader reader2 = DirectoryReader.open(w2); + IndexSearcher searcher2 = newSearcher(reader2); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector1 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + HybridTopScoreDocCollector collector2 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector1.setWeight(weight1); + collector2.setWeight(weight2); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); + + LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0); + LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2); + BulkScorer scorer = weight1.bulkScorer(leafReaderContext); + scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); + BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2); + scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs()); + leafCollector2.finish(); + + Object results = hybridCollectorManager.reduce(List.of(collector1, collector2)); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(2, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(6, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertTrue(scoreDocs[2].score > 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION); + assertTrue(scoreDocs[4].score > 0); + + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[5].score, DELTA_FOR_ASSERTION); + // we have to assert that one of hits is max score because scores are generated for each run and order is not guaranteed + assertTrue(Float.compare(scoreDocs[2].score, maxScore) == 0 || Float.compare(scoreDocs[4].score, maxScore) == 0); + + w.close(); + reader.close(); + directory.close(); + w2.close(); + reader2.close(); + directory2.close(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java new file mode 100644 index 000000000..2147578c9 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import org.apache.lucene.search.ScoreDoc; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import static org.opensearch.neuralsearch.search.query.TopDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; + +public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { + + private static final float DELTA_FOR_ASSERTION = 0.001f; + + public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scores = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + createStartStopElementForHybridSearchResults(2) }; + + NullPointerException exception = assertThrows( + NullPointerException.class, + () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR) + ); + assertEquals("score docs cannot be null", exception.getMessage()); + + exception = assertThrows(NullPointerException.class, () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR)); + assertEquals("score docs cannot be null", exception.getMessage()); + + ScoreDoc[] lessElementsScoreDocs = new ScoreDoc[] { createStartStopElementForHybridSearchResults(2), new ScoreDoc(1, 0.7f) }; + + IllegalArgumentException notEnoughException = assertThrows( + IllegalArgumentException.class, + () -> scoreDocsMerger.merge(lessElementsScoreDocs, scores, SCORE_DOC_BY_SCORE_COMPARATOR) + ); + assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); + + notEnoughException = assertThrows( + IllegalArgumentException.class, + () -> scoreDocsMerger.merge(scores, lessElementsScoreDocs, SCORE_DOC_BY_SCORE_COMPARATOR) + ); + assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); + } + + public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) }; + + ScoreDoc[] scoreDocsNew = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) }; + + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + + assertNotNull(mergedScoreDocs); + assertEquals(10, mergedScoreDocs.length); + + // check format, all elements one by one + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[1].score, 0); + assertScoreDoc(mergedScoreDocs[2], 1, 0.7f); + assertScoreDoc(mergedScoreDocs[3], 0, 0.5f); + assertScoreDoc(mergedScoreDocs[4], 2, 0.3f); + assertScoreDoc(mergedScoreDocs[5], 4, 0.3f); + assertScoreDoc(mergedScoreDocs[6], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[7].score, 0); + assertScoreDoc(mergedScoreDocs[8], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[9].score, 0); + } + + public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) }; + ScoreDoc[] scoreDocsNew = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) }; + + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + + assertNotNull(mergedScoreDocs); + assertEquals(8, mergedScoreDocs.length); + + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[1].score, 0); + assertScoreDoc(mergedScoreDocs[2], 1, 0.7f); + assertScoreDoc(mergedScoreDocs[3], 4, 0.3f); + assertScoreDoc(mergedScoreDocs[4], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[5].score, 0); + assertScoreDoc(mergedScoreDocs[6], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[7].score, 0); + } + + public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) }; + ScoreDoc[] scoreDocsNew = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + createStartStopElementForHybridSearchResults(2) }; + + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + + assertNotNull(mergedScoreDocs); + assertEquals(4, mergedScoreDocs.length); + // check format, all elements one by one + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[1].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[2].score, 0); + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[3].score, 0); + } + + private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) { + assertEquals(expectedDocId, scoreDoc.doc); + assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java new file mode 100644 index 000000000..5a99f3f3a --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -0,0 +1,255 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.SneakyThrows; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; + +public class TopDocsMergerTests extends OpenSearchQueryTestCase { + + private static final float DELTA_FOR_ASSERTION = 0.001f; + + @SneakyThrows + public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(6, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 5 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 5 + 1 + 2 + 2 = 10 + assertEquals(10, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 0, 0.5f); + assertScoreDoc(scoreDocs[4], 2, 0.3f); + assertScoreDoc(scoreDocs[5], 4, 0.3f); + assertScoreDoc(scoreDocs[6], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[7].score, 0); + assertScoreDoc(scoreDocs[8], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[9].score, 0); + } + + @SneakyThrows + public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 3 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 3 + 1 + 2 + 2 = 8 + assertEquals(8, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 4, 0.3f); + assertScoreDoc(scoreDocs[4], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[5].score, 0); + assertScoreDoc(scoreDocs[6], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[7].score, 0); + } + + @SneakyThrows + public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0); + TopDocs topDocsNew = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(0, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[2].score, 0); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, 0); + } + + @SneakyThrows + public void testThreeSequentialMerges_whenAllTopDocsHasHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore firstMergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(firstMergedTopDocsAndMaxScore); + + // merge results from collector 3 + TopDocs topDocsThirdCollector = new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(3), + createDelimiterElementForHybridSearchResults(3), + new ScoreDoc(3, 0.4f), + createDelimiterElementForHybridSearchResults(3), + new ScoreDoc(7, 0.85f), + new ScoreDoc(9, 0.2f), + createStartStopElementForHybridSearchResults(3) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreThirdCollector = new TopDocsAndMaxScore(topDocsThirdCollector, 0.85f); + TopDocsAndMaxScore finalMergedTopDocsAndMaxScore = topDocsMerger.merge( + firstMergedTopDocsAndMaxScore, + topDocsAndMaxScoreThirdCollector + ); + + assertEquals(0.85f, finalMergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(9, finalMergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, finalMergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 6 from sub-query1 and 3 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 6 + 3 + 2 + 2 = 13 + assertEquals(13, finalMergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = finalMergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 0, 0.5f); + assertScoreDoc(scoreDocs[4], 3, 0.4f); + assertScoreDoc(scoreDocs[5], 2, 0.3f); + assertScoreDoc(scoreDocs[6], 4, 0.3f); + assertScoreDoc(scoreDocs[7], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[8].score, 0); + assertScoreDoc(scoreDocs[9], 7, 0.85f); + assertScoreDoc(scoreDocs[10], 4, 0.6f); + assertScoreDoc(scoreDocs[11], 9, 0.2f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[12].score, 0); + } + + private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) { + assertEquals(expectedDocId, scoreDoc.doc); + assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java index 16f2f10ce..d84e196cd 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java @@ -4,16 +4,14 @@ */ package org.opensearch.neuralsearch.search.util; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQuerySpecialElement; import org.apache.lucene.search.ScoreDoc; import org.opensearch.common.Randomness; @@ -57,4 +55,23 @@ public void testCreateElements_whenCreateStartStopAndDelimiterElements_thenSucce assertEquals(docId, delimiterElement.doc); assertEquals(MAGIC_NUMBER_DELIMITER, delimiterElement.score, 0.0f); } + + public void testSpecialElementCheck_whenElementIsSpecialAndIsNotSpecial_thenSuccessful() { + int docId = 1; + ScoreDoc startStopElement = new ScoreDoc(docId, MAGIC_NUMBER_START_STOP); + assertTrue(isHybridQuerySpecialElement(startStopElement)); + assertFalse(isHybridQueryScoreDocElement(startStopElement)); + + ScoreDoc delimiterElement = new ScoreDoc(docId, MAGIC_NUMBER_DELIMITER); + assertTrue(isHybridQuerySpecialElement(delimiterElement)); + assertFalse(isHybridQueryScoreDocElement(delimiterElement)); + } + + public void testScoreElementCheck_whenElementIsSpecialAndIsNotSpecial_thenSuccessful() { + int docId = 1; + float score = Randomness.get().nextFloat(); + ScoreDoc startStopElement = new ScoreDoc(docId, score); + assertFalse(isHybridQuerySpecialElement(startStopElement)); + assertTrue(isHybridQueryScoreDocElement(startStopElement)); + } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index f0b3ba1af..b3567e1ff 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -83,6 +83,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); + protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; protected final ClassLoader classLoader = this.getClass().getClassLoader();