-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Query phase searcher #204
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,148 @@ | ||||||
/* | ||||||
* Copyright OpenSearch Contributors | ||||||
* SPDX-License-Identifier: Apache-2.0 | ||||||
*/ | ||||||
|
||||||
package org.opensearch.neuralsearch.search.query; | ||||||
|
||||||
import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; | ||||||
|
||||||
import java.io.IOException; | ||||||
import java.util.LinkedList; | ||||||
import java.util.List; | ||||||
|
||||||
import lombok.extern.log4j.Log4j2; | ||||||
|
||||||
import org.apache.lucene.index.IndexReader; | ||||||
import org.apache.lucene.search.Query; | ||||||
import org.apache.lucene.search.ScoreDoc; | ||||||
import org.apache.lucene.search.TopDocs; | ||||||
import org.apache.lucene.search.TotalHitCountCollector; | ||||||
import org.apache.lucene.search.TotalHits; | ||||||
import org.opensearch.common.lucene.search.TopDocsAndMaxScore; | ||||||
import org.opensearch.neuralsearch.query.HybridQuery; | ||||||
import org.opensearch.neuralsearch.search.CompoundTopDocs; | ||||||
import org.opensearch.neuralsearch.search.HitsThresholdChecker; | ||||||
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; | ||||||
import org.opensearch.search.DocValueFormat; | ||||||
import org.opensearch.search.internal.ContextIndexSearcher; | ||||||
import org.opensearch.search.internal.SearchContext; | ||||||
import org.opensearch.search.query.QueryCollectorContext; | ||||||
import org.opensearch.search.query.QueryPhase; | ||||||
import org.opensearch.search.query.QuerySearchResult; | ||||||
import org.opensearch.search.query.TopDocsCollectorContext; | ||||||
import org.opensearch.search.rescore.RescoreContext; | ||||||
import org.opensearch.search.sort.SortAndFormats; | ||||||
|
||||||
import com.google.common.annotations.VisibleForTesting; | ||||||
|
||||||
/** | ||||||
* Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the | ||||||
* upstream standard implementation of searcher is called. | ||||||
*/ | ||||||
@Log4j2 | ||||||
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher { | ||||||
|
||||||
public boolean searchWith( | ||||||
final SearchContext searchContext, | ||||||
final ContextIndexSearcher searcher, | ||||||
final Query query, | ||||||
final LinkedList<QueryCollectorContext> collectors, | ||||||
final boolean hasFilterCollector, | ||||||
final boolean hasTimeout | ||||||
) throws IOException { | ||||||
if (query instanceof HybridQuery) { | ||||||
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); | ||||||
} | ||||||
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); | ||||||
} | ||||||
|
||||||
/**
 * Executes the query with a {@link HybridTopScoreDocCollector} and stores the collected results in the
 * query result of the given search context.
 *
 * NOTE(review): the superclass exposes a method with the same name and argument list (it is invoked through
 * {@code super.searchWithCollector(...)} in {@code searchWith}), so this looks like an override - consider
 * marking it with {@code @Override}.
 *
 * @param searchContext context of the current search request; supplies size/from, rescore contexts,
 *                      total-hits tracking, terminate-after setting and the query result that is populated
 * @param searcher searcher used to execute the query
 * @param query query to execute
 * @param collectors collector contexts; a newly created top-docs collector context is added at the front
 * @param hasFilterCollector forwarded to {@code createTopDocsCollectorContext}
 * @param hasTimeout not read by this method
 * @return {@code false} when the requested size is 0; otherwise {@code true} if the search context has at
 *         least one rescore context, {@code false} if it has none
 * @throws IOException if the underlying search fails
 */
@VisibleForTesting | ||||||
protected boolean searchWithCollector( | ||||||
heemin32 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
final SearchContext searchContext, | ||||||
final ContextIndexSearcher searcher, | ||||||
final Query query, | ||||||
final LinkedList<QueryCollectorContext> collectors, | ||||||
final boolean hasFilterCollector, | ||||||
final boolean hasTimeout | ||||||
) throws IOException { | ||||||
log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId()); | ||||||
|
||||||
// Register a top-docs collector context at the head of the caller-supplied list (mutates the argument).
final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); | ||||||
collectors.addFirst(topDocsFactory); | ||||||
// No hits requested: only run a hit-counting collector and report that no rescore is needed.
// NOTE(review): the count gathered by this collector is not read anywhere in this method - confirm
// whether the query result is expected to be populated for size == 0.
if (searchContext.size() == 0) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we move this to the top of method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure if that's possible, in a caller method we're checking for the query type and only call this one in case of specific query. This check may ruin logic for other query types and doing actual search with TotalHitCountCollector is not its responsibility There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, you mean this line should be called no matter what?
|
||||||
final TotalHitCountCollector collector = new TotalHitCountCollector(); | ||||||
searcher.search(query, collector); | ||||||
return false; | ||||||
} | ||||||
// Number of docs to collect: from + size, capped by the reader's doc count ...
final IndexReader reader = searchContext.searcher().getIndexReader(); | ||||||
int totalNumDocs = Math.max(0, reader.numDocs()); | ||||||
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); | ||||||
// ... and raised to the largest rescore window size when rescore contexts are present.
final boolean shouldRescore = !searchContext.rescore().isEmpty(); | ||||||
if (shouldRescore) { | ||||||
for (RescoreContext rescoreContext : searchContext.rescore()) { | ||||||
numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); | ||||||
} | ||||||
} | ||||||
|
||||||
final QuerySearchResult queryResult = searchContext.queryResult(); | ||||||
|
||||||
// The hits threshold is the larger of the number of docs to collect and the total-hits tracking limit.
final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( | ||||||
numDocs, | ||||||
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) | ||||||
); | ||||||
|
||||||
searcher.search(query, collector); | ||||||
|
||||||
// terminate_after was requested but nothing has set the flag yet: record explicitly "not terminated early".
if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { | ||||||
queryResult.terminatedEarly(false); | ||||||
} | ||||||
|
||||||
setTopDocsInQueryResult(queryResult, collector, searchContext); | ||||||
|
||||||
return shouldRescore; | ||||||
} | ||||||
|
||||||
private void setTopDocsInQueryResult( | ||||||
final QuerySearchResult queryResult, | ||||||
final HybridTopScoreDocCollector collector, | ||||||
final SearchContext searchContext | ||||||
) { | ||||||
final List<TopDocs> topDocs = collector.topDocs(); | ||||||
final float maxScore = getMaxScore(topDocs); | ||||||
final TopDocs newTopDocs = new CompoundTopDocs(getTotalHits(searchContext, topDocs), topDocs); | ||||||
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); | ||||||
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); | ||||||
} | ||||||
|
||||||
private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs) { | ||||||
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); | ||||||
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED | ||||||
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO | ||||||
: TotalHits.Relation.EQUAL_TO; | ||||||
if (topDocs == null || topDocs.size() == 0) { | ||||||
return new TotalHits(0, relation); | ||||||
} | ||||||
long maxTotalHits = topDocs.get(0).totalHits.value; | ||||||
for (TopDocs topDoc : topDocs) { | ||||||
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value); | ||||||
} | ||||||
return new TotalHits(maxTotalHits, relation); | ||||||
} | ||||||
|
||||||
private float getMaxScore(List<TopDocs> topDocs) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ack There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
if (topDocs.size() == 0) { | ||||||
return Float.NaN; | ||||||
} else { | ||||||
return topDocs.stream() | ||||||
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) | ||||||
.map(scoreDoc -> scoreDoc.score) | ||||||
.max(Float::compare) | ||||||
.get(); | ||||||
} | ||||||
} | ||||||
|
||||||
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { | ||||||
return sortAndFormats == null ? null : sortAndFormats.formats; | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just add the boost value in exception message itself and remove this logging?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ack. As for putting a more detailed error message in the exception, I'd rather not do it: I have been advised several times by the security team not to put anything user-provided into user-facing content. For example, if this exception is propagated to the UI, something like this could cause execution of a malicious script.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True. However, we have already validated that the input is a float value in this case, right? There is a security risk only when we return unvalidated customer input.
Also, I don't think there is much benefit in logging the boost value here either, imo.