Skip to content

Commit

Permalink
Adding merge logic for multiple collector result case
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jun 21, 2024
1 parent 3e4a2ef commit 5c8bcc0
Show file tree
Hide file tree
Showing 7 changed files with 778 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fixed merge logic for multiple collector result case ([#800](https://github.com/opensearch-project/neural-search/pull/800))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.search.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
Expand Down Expand Up @@ -31,11 +32,14 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.search.TotalHits.Relation;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement;

/**
* Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits.
Expand All @@ -44,9 +48,9 @@
@RequiredArgsConstructor
public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3;
private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
@Nullable
Expand All @@ -62,7 +66,6 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
boolean isSingleShard = searchContext.numberOfShards() == 1;
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

Expand All @@ -83,15 +86,13 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight
)
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight
Expand Down Expand Up @@ -138,16 +139,36 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
List<ReduceableSearchResult> results = new ArrayList<>();
for (HybridTopScoreDocCollector hybridTopScoreDocCollector : hybridTopScoreDocCollectors) {
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());

results.add((QuerySearchResult result) -> {
// this is case of first collector, query result object doesn't have any top docs set, so we can
// just set new top docs without merge
if (result.hasConsumedTopDocs()) {
result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats));
return;
}
// in this case top docs are already present in result, and we need to merge next result object with what we have.
// if collector doesn't have any hits we can just skip it and save some cycles by not doing merge
if (newTopDocs.totalHits.value == 0) {
return;
}
// we need to do actual merge because query result and current collector both have some score hits
TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs();
result.topDocs(
mergeTopDocsAndMaxScores(originalTotalDocsAndHits, topDocsAndMaxScore),
getSortValueFormats(sortAndFormats)
);
});
}
return reduceCollectorResults(results);
}
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
}
Expand Down Expand Up @@ -195,15 +216,10 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
return new TopDocs(totalHits, scoreDocs);
}

private TotalHits getTotalHits(
int trackTotalHitsUpTo,
final List<TopDocs> topDocs,
final boolean isSingleShard,
final long maxTotalHits
) {
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final long maxTotalHits) {
final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? Relation.GREATER_THAN_OR_EQUAL_TO
: Relation.EQUAL_TO;
if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}
Expand All @@ -215,6 +231,109 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats
return sortAndFormats == null ? null : sortAndFormats.formats;
}

private ReduceableSearchResult reduceCollectorResults(List<ReduceableSearchResult> results) {
return (result) -> {
for (ReduceableSearchResult r : results) {
r.reduce(result);
}
};
}

@VisibleForTesting
protected TopDocsAndMaxScore mergeTopDocsAndMaxScores(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) {
if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) {
return source;
}
// we need to merge hits per individual sub-query
// format of results in both new and source TopDocs is following
// doc_id | magic_number_1
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_1
ScoreDoc[] sourceScoreDocs = source.topDocs.scoreDocs;
ScoreDoc[] newScoreDocs = newTopDocs.topDocs.scoreDocs;

List<ScoreDoc> mergedScoreDocs = mergedScoreDocs(sourceScoreDocs, newScoreDocs, Comparator.comparing((scoreDoc) -> scoreDoc.score));
TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs);
TopDocsAndMaxScore result = new TopDocsAndMaxScore(
new TopDocs(mergedTotalHits, mergedScoreDocs.toArray(new ScoreDoc[0])),
Math.max(source.maxScore, newTopDocs.maxScore)
);
return result;
}

/**
* Merge two score docs objects, result ScoreDocs[] object will have all hits per sub-query from both original objects.
* Logic is based on assumption that hits of every sub-query are sorted by score.
* Method returns new object and doesn't mutate original ScoreDocs arrays.
* @param sourceScoreDocs original score docs from query result
* @param newScoreDocs new score docs that we need to merge into existing scores
* @return merged array of ScoreDocs objects
*/
private List<ScoreDoc> mergedScoreDocs(
final ScoreDoc[] sourceScoreDocs,
final ScoreDoc[] newScoreDocs,
final Comparator<ScoreDoc> scoreDocComparator
) {
if (Objects.requireNonNull(sourceScoreDocs).length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC
|| Objects.requireNonNull(newScoreDocs).length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) {
throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements");
}
// we overshoot and preallocate more than we need - length of both top docs combined.
// we will take only portion of the array at the end
List<ScoreDoc> mergedScoreDocs = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length);
int sourcePointer = 0;
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
// new pointer is set to 1 as we don't care about it start-stop element
int newPointer = 1;

while (sourcePointer < sourceScoreDocs.length - 1 && newPointer < newScoreDocs.length - 1) {
// every iteration is for results of one sub-query
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
newPointer++;
// simplest case when both arrays have results for sub-query
while (sourcePointer < sourceScoreDocs.length
&& isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])
&& newPointer < newScoreDocs.length
&& isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
if (scoreDocComparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) {
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
} else {
mergedScoreDocs.add(newScoreDocs[newPointer]);
newPointer++;
}
}
// at least one object got exhausted at this point, now merge all elements from object that's left
while (sourcePointer < sourceScoreDocs.length && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])) {
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
}
while (newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
mergedScoreDocs.add(newScoreDocs[newPointer]);
newPointer++;
}
}
mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]);
return mergedScoreDocs;
}

private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) {
// merged value is a lower bound - if both are equal_to than merged will also be equal_to,
// otherwise assign greater_than_or_equal
Relation mergedHitsRelation = source.topDocs.totalHits.relation == Relation.GREATER_THAN_OR_EQUAL_TO
|| newTopDocs.topDocs.totalHits.relation == Relation.GREATER_THAN_OR_EQUAL_TO
? Relation.GREATER_THAN_OR_EQUAL_TO
: Relation.EQUAL_TO;
return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation);
}

/**
* Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
* use saved state of collector
Expand All @@ -225,12 +344,11 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager
public HybridCollectorNonConcurrentManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
}

Expand All @@ -255,12 +373,11 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag
public HybridCollectorConcurrentSearchManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,28 @@ public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) {
public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) {
return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0;
}

/**
* Checking if passed scoreDocs object is a special element (start/stop or delimiter) in the list of hybrid query result scores
* @param scoreDoc score doc object to check on
* @return true if it is a special element
*/
public static boolean isHybridQuerySpecialElement(final ScoreDoc scoreDoc) {
if (Objects.isNull(scoreDoc)) {
return false;
}
return isHybridQueryStartStopElement(scoreDoc) || isHybridQueryDelimiterElement(scoreDoc);
}

/**
* Checking if passed scoreDocs object is a document score element
* @param scoreDoc score doc object to check on
* @return true if element has score
*/
public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) {
if (Objects.isNull(scoreDoc)) {
return false;
}
return !isHybridQuerySpecialElement(scoreDoc);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,43 +100,43 @@ protected boolean preserveClusterUponCompletion() {

@SneakyThrows
public void testPipelineAggs_whenConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
testAvgSumMinMaxAggs();
}

@SneakyThrows
public void testPipelineAggs_whenConcurrentSearchDisabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", false);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
testAvgSumMinMaxAggs();
}

@SneakyThrows
public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
testMaxAggsOnSingleShardCluster();
}

@SneakyThrows
public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchDisabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", false);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
testMaxAggsOnSingleShardCluster();
}

@SneakyThrows
public void testBucketAndNestedAggs_whenConcurrentSearchDisabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", false);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
testDateRange();
}

@SneakyThrows
public void testBucketAndNestedAggs_whenConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
testDateRange();
}

@SneakyThrows
public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);

try {
prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE);
Expand Down Expand Up @@ -177,14 +177,14 @@ public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSu

@SneakyThrows
public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", false);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
testPostFilterWithSimpleHybridQuery(false, true);
testPostFilterWithComplexHybridQuery(false, true);
}

@SneakyThrows
public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
testPostFilterWithSimpleHybridQuery(false, true);
testPostFilterWithComplexHybridQuery(false, true);
}
Expand Down Expand Up @@ -420,14 +420,14 @@ private void testAvgSumMinMaxAggs() {

@SneakyThrows
public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", false);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
testPostFilterWithSimpleHybridQuery(true, true);
testPostFilterWithComplexHybridQuery(true, true);
}

@SneakyThrows
public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
testPostFilterWithSimpleHybridQuery(true, true);
testPostFilterWithComplexHybridQuery(true, true);
}
Expand Down
Loading

0 comments on commit 5c8bcc0

Please sign in to comment.