Add query phase searcher and basic tests
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Jun 27, 2023
1 parent 2de63dd commit db56ae4
Showing 13 changed files with 780 additions and 95 deletions.
@@ -10,6 +10,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import org.opensearch.client.Client;
@@ -26,13 +27,15 @@
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

@@ -74,4 +77,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
}

@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
return Optional.of(new HybridQueryPhaseSearcher());
}
}
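
The hunk above registers the new HybridQueryPhaseSearcher through the SearchPlugin#getQueryPhaseSearcher extension point. As a rough, hypothetical sketch of how a single plugin-provided searcher could be picked over the default one (the resolver class and method below are illustrative only, not OpenSearch core code):

import java.util.List;
import java.util.Optional;

import org.opensearch.plugins.SearchPlugin;
import org.opensearch.search.query.QueryPhaseSearcher;

// Hypothetical helper, for illustration only: take the first plugin-provided
// QueryPhaseSearcher, falling back to the default implementation otherwise.
final class QueryPhaseSearcherResolver {
    static QueryPhaseSearcher resolve(List<SearchPlugin> plugins, QueryPhaseSearcher defaultSearcher) {
        return plugins.stream()
            .map(SearchPlugin::getQueryPhaseSearcher)   // Optional.empty() unless a plugin overrides it
            .filter(Optional::isPresent)
            .map(Optional::get)
            .findFirst()
            .orElse(defaultSearcher);
    }
}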
@@ -189,6 +189,14 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
} else {
if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
boost = parser.floatValue();
// regular boost functionality is not supported; users should rely on score normalization methods to adjust scores
if (boost != DEFAULT_BOOST) {
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, BOOST_FIELD));
throw new ParsingException(
parser.getTokenLocation(),
String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, BOOST_FIELD)
);
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else {
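
The hunk above makes HybridQueryBuilder reject an explicit boost while parsing, since per-query boosting is expected to be handled through score normalization instead. A minimal test-style sketch of the new behavior, not part of this commit; it is assumed to live in a class extending OpenSearchTestCase, and the JSON body is illustrative:

// Hypothetical test sketch: an explicit "boost" on a hybrid query fails at parse time.
public void testBoostIsRejected() throws IOException {
    XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startArray("queries")
        .startObject().startObject("match_all").endObject().endObject()
        .endArray()
        .field("boost", 2.0f)
        .endObject();
    XContentParser contentParser = createParser(xContentBuilder);
    contentParser.nextToken(); // move to the START_OBJECT of the hybrid query body
    ParsingException exception = expectThrows(
        ParsingException.class,
        () -> HybridQueryBuilder.fromXContent(contentParser)
    );
    assertTrue(exception.getMessage().contains("does not support"));
}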
@@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.search;

import java.util.Arrays;
import java.util.List;

import lombok.Getter;
import lombok.ToString;
@@ -21,23 +22,23 @@
public class CompoundTopDocs extends TopDocs {

@Getter
private TopDocs[] compoundTopDocs;
private List<TopDocs> compoundTopDocs;

public CompoundTopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) {
super(totalHits, scoreDocs);
}

public CompoundTopDocs(TotalHits totalHits, TopDocs[] docs) {
public CompoundTopDocs(TotalHits totalHits, List<TopDocs> docs) {
// we pass a clone of the score docs from the sub-query that has the most hits
super(totalHits, cloneLargestScoreDocs(docs));
this.compoundTopDocs = docs;
}

private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = null;
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
@@ -48,9 +49,6 @@ private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
maxScoreDocs = topDoc.scoreDocs;
}
}
if (maxScoreDocs == null) {
return null;
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
}
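
The hunk above switches CompoundTopDocs from a TopDocs[] field to a List&lt;TopDocs&gt; and makes cloneLargestScoreDocs return an empty array instead of null when there is nothing to copy. A small illustrative sketch of the resulting behavior (not from this commit; document ids and scores are made up):

import java.util.List;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.search.CompoundTopDocs;

public final class CompoundTopDocsSketch {
    public static void main(String[] args) {
        // Two sub-query results of a hybrid query: a keyword match and a neural match.
        TopDocs keywordHits = new TopDocs(
            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] { new ScoreDoc(1, 3.2f), new ScoreDoc(4, 1.1f) }
        );
        TopDocs neuralHits = new TopDocs(
            new TotalHits(1, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] { new ScoreDoc(7, 0.9f) }
        );

        CompoundTopDocs compound = new CompoundTopDocs(
            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
            List.of(keywordHits, neuralHits)
        );

        // scoreDocs is a deep copy of the sub-result with the most hits (keywordHits here),
        // while getCompoundTopDocs() still exposes both per-sub-query results.
        System.out.println(compound.scoreDocs.length);             // 2
        System.out.println(compound.getCompoundTopDocs().size());  // 2
    }
}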
@@ -6,6 +6,8 @@
package org.opensearch.neuralsearch.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import lombok.Getter;
@@ -31,9 +33,7 @@
public class HybridTopScoreDocCollector implements Collector {
private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private int docBase;
private float minCompetitiveScore;
private final HitsThresholdChecker hitsThresholdChecker;
private ScoreDoc pqTop;
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
private int[] totalHits;
private final int numOfHits;
@@ -48,15 +48,13 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
docBase = context.docBase;
minCompetitiveScore = 0f;

return new TopScoreDocCollector.ScorerLeafCollector() {
HybridQueryScorer compoundQueryScorer;

@Override
public void setScorer(Scorable scorer) throws IOException {
super.setScorer(scorer);
updateMinCompetitiveScore(scorer);
compoundQueryScorer = (HybridQueryScorer) scorer;
}

@@ -93,29 +91,19 @@ public ScoreMode scoreMode() {
return hitsThresholdChecker.scoreMode();
}

protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
if (hitsThresholdChecker.isThresholdReached() && pqTop != null && pqTop.score != Float.NEGATIVE_INFINITY) { // -Infinity is the
// boundary score
// we have multiple identical doc id and collect in doc id order, we need next float
float localMinScore = Math.nextUp(pqTop.score);
if (localMinScore > minCompetitiveScore) {
scorer.setMinCompetitiveScore(localMinScore);
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
minCompetitiveScore = localMinScore;
}
}
}

/**
* Get the resulting collection of TopDocs for a hybrid query after search has run for each of its sub-queries
* @return list of TopDocs, one entry per sub-query
*/
public TopDocs[] topDocs() {
TopDocs[] topDocs = new TopDocs[compoundScores.length];
public List<TopDocs> topDocs() {
List<TopDocs> topDocs;
if (compoundScores == null) {
return new ArrayList<>();
}
topDocs = new ArrayList<>(compoundScores.length);
for (int i = 0; i < compoundScores.length; i++) {
int qTopSize = totalHits[i];
TopDocs topDocsPerQuery = topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize);
topDocs[i] = topDocsPerQuery;
topDocs.add(topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize));
}
return topDocs;
}
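
The hunk above drops the unused min-competitive-score tracking from HybridTopScoreDocCollector and changes topDocs() to return a List&lt;TopDocs&gt;, one entry per sub-query, with an empty list as a safe default. A rough usage sketch follows; it is not part of this commit, and indexSearcher and hybridQuery are assumed to already exist:

// Assumed to exist: an IndexSearcher over the shard and a HybridQuery with N sub-queries.
HybridTopScoreDocCollector hybridCollector = new HybridTopScoreDocCollector(
    10,                                   // hits to keep per sub-query
    new HitsThresholdChecker(1_000)       // threshold driving the reported ScoreMode
);
indexSearcher.search(hybridQuery, hybridCollector);

// One TopDocs per sub-query, in the same order as the sub-queries of the hybrid query.
List<TopDocs> perSubQuery = hybridCollector.topDocs();
for (int i = 0; i < perSubQuery.size(); i++) {
    System.out.println("sub-query " + i + ": " + perSubQuery.get(i).totalHits);
}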
@@ -0,0 +1,148 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.search.query;

import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;

import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.TopDocsCollectorContext;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

/**
* Custom search implementation to be used in {@link QueryPhase} for hybrid query search. For queries other than hybrid,
* the standard upstream searcher implementation is called.
*/
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

private Function<List<TopDocs>, TotalHits> totalHitsSupplier;
private Function<List<TopDocs>, Float> maxScoreSupplier;
protected SortAndFormats sortAndFormats;

public boolean searchWith(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

protected boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
log.debug(String.format(Locale.ROOT, "searching with custom doc collector, shard %s", searchContext.shardTarget().getShardId()));

final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
collectors.addFirst(topDocsFactory);

final IndexReader reader = searchContext.searcher().getIndexReader();
int totalNumDocs = Math.max(0, reader.numDocs());
if (searchContext.size() == 0) {
final TotalHitCountCollector collector = new TotalHitCountCollector();
searcher.search(query, collector);
return false;
}
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
final boolean rescore = !searchContext.rescore().isEmpty();
if (rescore) {
assert searchContext.sort() == null;
for (RescoreContext rescoreContext : searchContext.rescore()) {
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
}
}

final QuerySearchResult queryResult = searchContext.queryResult();

final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
);
totalHitsSupplier = topDocs -> {
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
if (topDocs == null || topDocs.size() == 0) {
return new TotalHits(0, relation);
}
long maxTotalHits = topDocs.get(0).totalHits.value;
for (TopDocs topDoc : topDocs) {
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
}
return new TotalHits(maxTotalHits, relation);
};
maxScoreSupplier = topDocs -> {
if (topDocs.size() == 0) {
return Float.NaN;
} else {
return topDocs.stream()
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
.map(scoreDoc -> scoreDoc.score)
.max(Float::compare)
.get();
}
};
sortAndFormats = searchContext.sort();

searcher.search(query, collector);

if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
queryResult.terminatedEarly(false);
}

setTopDocsInQueryResult(queryResult, collector);

return rescore;
}

void setTopDocsInQueryResult(final QuerySearchResult queryResult, final HybridTopScoreDocCollector collector) {
final List<TopDocs> topDocs = collector.topDocs();
float maxScore = maxScoreSupplier.apply(topDocs);
final TopDocs newTopDocs = new CompoundTopDocs(totalHitsSupplier.apply(topDocs), topDocs);
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats());
}

private DocValueFormat[] getSortValueFormats() {
return sortAndFormats == null ? null : sortAndFormats.formats;
}
}
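
The new HybridQueryPhaseSearcher above routes hybrid queries to the custom HybridTopScoreDocCollector and then reduces the per-sub-query results: total hits becomes the maximum totalHits value across sub-queries, and max score is the best top-document score across sub-queries. A standalone sketch of that reduction on made-up data (not from this commit):

import java.util.List;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

public final class HybridReductionSketch {
    public static void main(String[] args) {
        List<TopDocs> perSubQuery = List.of(
            new TopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                new ScoreDoc[] { new ScoreDoc(2, 1.5f), new ScoreDoc(9, 0.7f), new ScoreDoc(5, 0.2f) }),
            new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO),
                new ScoreDoc[] { new ScoreDoc(11, 2.4f) })
        );

        // Total hits: maximum totalHits.value over all sub-queries.
        long totalHits = perSubQuery.stream().mapToLong(td -> td.totalHits.value).max().orElse(0L);

        // Max score: best top-document score over all sub-queries that returned hits.
        float maxScore = (float) perSubQuery.stream()
            .filter(td -> td.scoreDocs.length > 0)
            .mapToDouble(td -> td.scoreDocs[0].score)
            .max()
            .orElse(Double.NaN);

        System.out.println("total hits: " + totalHits); // 3
        System.out.println("max score:  " + maxScore);  // 2.4
    }
}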
@@ -62,6 +62,12 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

@Before
public void setupSettings() {
if (isUpdateClusterSettings()) {
updateClusterSettings();
}
}

protected void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
// default threshold for the native circuit breaker is 90, which may not be enough on the test runner machine
updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
@@ -514,4 +520,8 @@ protected void deleteModel(String modelId) {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}

public boolean isUpdateClusterSettings() {
return true;
}
}
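
The hunk above wraps the cluster settings update in a new isUpdateClusterSettings() hook so that integration tests can opt out of it. A minimal hypothetical subclass (not part of this commit):

// Hypothetical test class: it manages cluster settings on its own,
// so it skips the automatic update performed in setupSettings().
public class MyManagedSettingsIT extends BaseNeuralSearchIT {

    @Override
    public boolean isUpdateClusterSettings() {
        return false;
    }
}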
@@ -6,10 +6,13 @@
package org.opensearch.neuralsearch.plugin;

import java.util.List;
import java.util.Optional;

import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.test.OpenSearchTestCase;

public class NeuralSearchTests extends OpenSearchTestCase {
@@ -23,4 +26,13 @@ public void testQuerySpecs() {
assertTrue(querySpecs.stream().anyMatch(spec -> NeuralQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
}

public void testQueryPhaseSearcher() {
NeuralSearch plugin = new NeuralSearch();
Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcher);
assertFalse(queryPhaseSearcher.isEmpty());
assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
}
}