Skip to content

Commit

Permalink
Address review comments, mostly refactorings
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jun 28, 2023
1 parent 3955bfc commit fd9c700
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,10 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
boost = parser.floatValue();
// regular boost functionality is not supported, user should use score normalization methods to manipulate with scores
if (boost != DEFAULT_BOOST) {
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, BOOST_FIELD));
throw new ParsingException(
parser.getTokenLocation(),
String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, BOOST_FIELD)
log.error(
String.format(Locale.ROOT, "[%s] query does not support provided value %.4f for [%s]", NAME, boost, BOOST_FIELD)
);
throw new ParsingException(parser.getTokenLocation(), "[{}] query does not support [{}]", NAME, BOOST_FIELD);
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -96,15 +98,12 @@ public ScoreMode scoreMode() {
* @return
*/
public List<TopDocs> topDocs() {
List<TopDocs> topDocs;
if (compoundScores == null) {
return new ArrayList<>();
}
topDocs = new ArrayList(compoundScores.length);
for (int i = 0; i < compoundScores.length; i++) {
int qTopSize = totalHits[i];
topDocs.add(topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize));
}
final List<TopDocs> topDocs = IntStream.range(0, compoundScores.length)
.mapToObj(i -> topDocsPerQuery(0, Math.min(totalHits[i], compoundScores[i].size()), compoundScores[i], totalHits[i]))
.collect(Collectors.toList());
return topDocs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;

import lombok.extern.log4j.Log4j2;

Expand All @@ -36,55 +34,52 @@
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

import com.google.common.annotations.VisibleForTesting;

/**
* Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the
* upstream standard implementation of searcher is called.
*/
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

private Function<List<TopDocs>, TotalHits> totalHitsSupplier;
private Function<List<TopDocs>, Float> maxScoreSupplier;
protected SortAndFormats sortAndFormats;

public boolean searchWith(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
final SearchContext searchContext,
final ContextIndexSearcher searcher,
final Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

@VisibleForTesting
protected boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
final SearchContext searchContext,
final ContextIndexSearcher searcher,
final Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
log.debug(String.format(Locale.ROOT, "searching with custom doc collector, shard %s", searchContext.shardTarget().getShardId()));
log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());

final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
collectors.addFirst(topDocsFactory);

final IndexReader reader = searchContext.searcher().getIndexReader();
int totalNumDocs = Math.max(0, reader.numDocs());
if (searchContext.size() == 0) {
final TotalHitCountCollector collector = new TotalHitCountCollector();
searcher.search(query, collector);
return false;
}
final IndexReader reader = searchContext.searcher().getIndexReader();
int totalNumDocs = Math.max(0, reader.numDocs());
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
final boolean rescore = !searchContext.rescore().isEmpty();
if (rescore) {
assert searchContext.sort() == null;
final boolean shouldRescore = !searchContext.rescore().isEmpty();
if (shouldRescore) {
for (RescoreContext rescoreContext : searchContext.rescore()) {
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
}
Expand All @@ -96,53 +91,58 @@ protected boolean searchWithCollector(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
);
totalHitsSupplier = topDocs -> {
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
if (topDocs == null || topDocs.size() == 0) {
return new TotalHits(0, relation);
}
long maxTotalHits = topDocs.get(0).totalHits.value;
for (TopDocs topDoc : topDocs) {
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
}
return new TotalHits(maxTotalHits, relation);
};
maxScoreSupplier = topDocs -> {
if (topDocs.size() == 0) {
return Float.NaN;
} else {
return topDocs.stream()
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
.map(scoreDoc -> scoreDoc.score)
.max(Float::compare)
.get();
}
};
sortAndFormats = searchContext.sort();

searcher.search(query, collector);

if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
queryResult.terminatedEarly(false);
}

setTopDocsInQueryResult(queryResult, collector);
setTopDocsInQueryResult(queryResult, collector, searchContext);

return rescore;
return shouldRescore;
}

void setTopDocsInQueryResult(final QuerySearchResult queryResult, final HybridTopScoreDocCollector collector) {
private void setTopDocsInQueryResult(
final QuerySearchResult queryResult,
final HybridTopScoreDocCollector collector,
final SearchContext searchContext
) {
final List<TopDocs> topDocs = collector.topDocs();
float maxScore = maxScoreSupplier.apply(topDocs);
final TopDocs newTopDocs = new CompoundTopDocs(totalHitsSupplier.apply(topDocs), topDocs);
final float maxScore = getMaxScore(topDocs);
final TopDocs newTopDocs = new CompoundTopDocs(getTotalHits(searchContext, topDocs), topDocs);
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats());
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
}

private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs) {
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
if (topDocs == null || topDocs.size() == 0) {
return new TotalHits(0, relation);
}
long maxTotalHits = topDocs.get(0).totalHits.value;
for (TopDocs topDoc : topDocs) {
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
}
return new TotalHits(maxTotalHits, relation);
}

private float getMaxScore(List<TopDocs> topDocs) {
if (topDocs.size() == 0) {
return Float.NaN;
} else {
return topDocs.stream()
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
.map(scoreDoc -> scoreDoc.score)
.max(Float::compare)
.get();
}
}

private DocValueFormat[] getSortValueFormats() {
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
return sortAndFormats == null ? null : sortAndFormats.formats;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.plugin;

import static org.mockito.Mockito.mock;

import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -18,8 +20,6 @@
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.test.OpenSearchTestCase;

import static org.mockito.Mockito.mock;

public class NeuralSearchTests extends OpenSearchTestCase {

public void testQuerySpecs() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,30 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery(
assertEquals(termSubQuery.value(), termQueryBuilder.value());
}

/**
* Tests query with boost:
* {
* "query": {
* "hybrid": {
* "queries": [
* {
* "term": {
* "text": "keyword"
* }
* },
* {
* "term": {
* "text": "keyword"
* }
* }
* ],
* "boost" : 2.0
* }
* }
* }
*/
@SneakyThrows
public void testBoost_whenNonDefaultBoostSet_thenFail() {
// create query with 6 sub-queries, which is more than current max allowed
XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder()
.startObject()
.startArray("queries")
Expand Down
Loading

0 comments on commit fd9c700

Please sign in to comment.