Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Query phase searcher #204

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import org.opensearch.client.Client;
Expand All @@ -26,13 +27,15 @@
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

Expand Down Expand Up @@ -74,4 +77,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
}

/**
 * Supplies the query phase searcher contributed by this plugin.
 *
 * @return a non-empty {@link Optional} wrapping a newly created {@link HybridQueryPhaseSearcher}
 */
@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
    final QueryPhaseSearcher hybridSearcher = new HybridQueryPhaseSearcher();
    return Optional.of(hybridSearcher);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
} else {
if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
boost = parser.floatValue();
// regular boost functionality is not supported; users should use score normalization methods to manipulate scores
if (boost != DEFAULT_BOOST) {
log.error(
String.format(Locale.ROOT, "[%s] query does not support provided value %.4f for [%s]", NAME, boost, BOOST_FIELD)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just add the boost value to the exception message itself and remove this logging?

Suggested change
String.format(Locale.ROOT, "[%s] query does not support provided value %.4f for [%s]", NAME, boost, BOOST_FIELD)
"[{}] query does not support provided value {} for [{}]", NAME, boost, BOOST_FIELD

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. As for the more detailed error message in the exception, I'd rather not do it: I've been advised several times by the security team not to put anything user-provided into user-facing content. For example, if this exception is propagated to a UI, something like this could cause execution of a malicious script.

Copy link
Collaborator

@heemin32 heemin32 Jun 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. However, we have already validated that the input is a float value in this case, right? There is a security risk only when we return unvalidated customer input.

Also, there won't be much benefit in logging the boost value here either, imo.

);
throw new ParsingException(parser.getTokenLocation(), "[{}] query does not support [{}]", NAME, BOOST_FIELD);
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.search;

import java.util.Arrays;
import java.util.List;

import lombok.Getter;
import lombok.ToString;
Expand All @@ -21,23 +22,23 @@
public class CompoundTopDocs extends TopDocs {

@Getter
private TopDocs[] compoundTopDocs;
private List<TopDocs> compoundTopDocs;

public CompoundTopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) {
super(totalHits, scoreDocs);
}

public CompoundTopDocs(TotalHits totalHits, TopDocs[] docs) {
public CompoundTopDocs(TotalHits totalHits, List<TopDocs> docs) {
// we pass clone of score docs from the sub-query that has most hits
super(totalHits, cloneLargestScoreDocs(docs));
this.compoundTopDocs = docs;
}

private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
private static ScoreDoc[] cloneLargestScoreDocs(List<TopDocs> docs) {
if (docs == null) {
return null;
}
ScoreDoc[] maxScoreDocs = null;
ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
int maxLength = -1;
for (TopDocs topDoc : docs) {
if (topDoc == null || topDoc.scoreDocs == null) {
Expand All @@ -48,9 +49,6 @@ private static ScoreDoc[] cloneLargestScoreDocs(TopDocs[] docs) {
maxScoreDocs = topDoc.scoreDocs;
}
}
if (maxScoreDocs == null) {
return null;
}
// do deep copy
return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
package org.opensearch.neuralsearch.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
Expand All @@ -31,9 +35,7 @@
public class HybridTopScoreDocCollector implements Collector {
private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private int docBase;
private float minCompetitiveScore;
private final HitsThresholdChecker hitsThresholdChecker;
private ScoreDoc pqTop;
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
private int[] totalHits;
private final int numOfHits;
Expand All @@ -48,15 +50,13 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
docBase = context.docBase;
minCompetitiveScore = 0f;

return new TopScoreDocCollector.ScorerLeafCollector() {
HybridQueryScorer compoundQueryScorer;

@Override
public void setScorer(Scorable scorer) throws IOException {
super.setScorer(scorer);
updateMinCompetitiveScore(scorer);
compoundQueryScorer = (HybridQueryScorer) scorer;
}

Expand Down Expand Up @@ -93,30 +93,17 @@ public ScoreMode scoreMode() {
return hitsThresholdChecker.scoreMode();
}

/**
 * Raises the minimum competitive score on the given scorer when the hits threshold has been reached
 * and a higher bound than the one recorded so far is available.
 *
 * @param scorer scorer that receives the new minimum competitive score
 * @throws IOException if the scorer fails while accepting the new minimum score
 */
protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
    if (!hitsThresholdChecker.isThresholdReached()) {
        return;
    }
    // -Infinity is the boundary score, nothing can be derived from it
    // (pqTop is presumably the current top of the priority queue - TODO confirm)
    if (pqTop == null || pqTop.score == Float.NEGATIVE_INFINITY) {
        return;
    }
    // identical doc ids may occur several times and docs are collected in doc id order,
    // so the next representable float above the top score is used
    final float candidateMinScore = Math.nextUp(pqTop.score);
    if (candidateMinScore <= minCompetitiveScore) {
        return;
    }
    scorer.setMinCompetitiveScore(candidateMinScore);
    totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
    minCompetitiveScore = candidateMinScore;
}

/**
* Get the resulting collection of TopDocs for the hybrid query after the search has run for each of its sub-queries
* @return collection of TopDocs, one element per entry of compoundScores
*/
public TopDocs[] topDocs() {
TopDocs[] topDocs = new TopDocs[compoundScores.length];
for (int i = 0; i < compoundScores.length; i++) {
int qTopSize = totalHits[i];
TopDocs topDocsPerQuery = topDocsPerQuery(0, Math.min(qTopSize, compoundScores[i].size()), compoundScores[i], qTopSize);
topDocs[i] = topDocsPerQuery;
public List<TopDocs> topDocs() {
if (compoundScores == null) {
return new ArrayList<>();
}
final List<TopDocs> topDocs = IntStream.range(0, compoundScores.length)
.mapToObj(i -> topDocsPerQuery(0, Math.min(totalHits[i], compoundScores[i].size()), compoundScores[i], totalHits[i]))
.collect(Collectors.toList());
return topDocs;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.search.query;

import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;

import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.TopDocsCollectorContext;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

import com.google.common.annotations.VisibleForTesting;

/**
* Custom searcher used during the {@link QueryPhase} for hybrid query search. For any query other than a hybrid query, the
* upstream default searcher implementation is called.
*/
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

/**
 * Dispatches the query phase search: a {@link HybridQuery} is handled by this class's own
 * {@code searchWithCollector}, every other query type is delegated to the parent implementation
 * of {@code searchWithCollector}.
 *
 * @param searchContext context of the current search request
 * @param searcher index searcher used to execute the query
 * @param query query to execute
 * @param collectors collector contexts assembled for this search
 * @param hasFilterCollector flag passed through unchanged to the collector-based search
 * @param hasTimeout flag passed through unchanged to the collector-based search
 * @return the value returned by the selected {@code searchWithCollector} implementation
 * @throws IOException if the underlying search fails
 */
@Override
public boolean searchWith(
    final SearchContext searchContext,
    final ContextIndexSearcher searcher,
    final Query query,
    final LinkedList<QueryCollectorContext> collectors,
    final boolean hasFilterCollector,
    final boolean hasTimeout
) throws IOException {
    if (query instanceof HybridQuery) {
        return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
    }
    // explicit super call so that non-hybrid queries bypass the override declared in this class
    return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

@VisibleForTesting
protected boolean searchWithCollector(
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
final SearchContext searchContext,
final ContextIndexSearcher searcher,
final Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());

final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
collectors.addFirst(topDocsFactory);
if (searchContext.size() == 0) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this to the top of method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure that's possible: in the caller method we check the query type and only call this method for one specific query type. This check may break the logic for other query types, and doing the actual search with TotalHitCountCollector is not its responsibility.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, you mean this line should be called no matter what?

final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
collectors.addFirst(topDocsFactory);

final TotalHitCountCollector collector = new TotalHitCountCollector();
searcher.search(query, collector);
return false;
}
final IndexReader reader = searchContext.searcher().getIndexReader();
int totalNumDocs = Math.max(0, reader.numDocs());
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
final boolean shouldRescore = !searchContext.rescore().isEmpty();
if (shouldRescore) {
for (RescoreContext rescoreContext : searchContext.rescore()) {
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
}
}

final QuerySearchResult queryResult = searchContext.queryResult();

final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
);

searcher.search(query, collector);

if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
queryResult.terminatedEarly(false);
}

setTopDocsInQueryResult(queryResult, collector, searchContext);

return shouldRescore;
}

/**
 * Builds a {@link CompoundTopDocs} from the collector's per-sub-query results and stores it,
 * together with the max score and the sort value formats, in the query result.
 *
 * @param queryResult query result that receives the top docs
 * @param collector collector holding the results of the executed search
 * @param searchContext context of the current search request
 */
private void setTopDocsInQueryResult(
    final QuerySearchResult queryResult,
    final HybridTopScoreDocCollector collector,
    final SearchContext searchContext
) {
    final List<TopDocs> subQueryTopDocs = collector.topDocs();
    final float highestScore = getMaxScore(subQueryTopDocs);
    final TotalHits totalHits = getTotalHits(searchContext, subQueryTopDocs);
    final TopDocsAndMaxScore scoredTopDocs = new TopDocsAndMaxScore(new CompoundTopDocs(totalHits, subQueryTopDocs), highestScore);
    final DocValueFormat[] sortValueFormats = getSortValueFormats(searchContext.sort());
    queryResult.topDocs(scoredTopDocs, sortValueFormats);
}

/**
 * Computes the total hits for the compound result as the largest total hit count among the
 * given per-sub-query results.
 *
 * @param searchContext context of the current search request, consulted for the total hits tracking setting
 * @param topDocs per-sub-query results; may be null or empty
 * @return total hits with value 0 when there are no results, otherwise the maximum value found; the relation is
 *         GREATER_THAN_OR_EQUAL_TO when total hits tracking is disabled and EQUAL_TO otherwise
 */
private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs) {
    final TotalHits.Relation relation;
    if (searchContext.trackTotalHitsUpTo() == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
    } else {
        relation = TotalHits.Relation.EQUAL_TO;
    }
    if (topDocs == null || topDocs.isEmpty()) {
        return new TotalHits(0, relation);
    }
    // the list is known to be non-empty here, so a maximum always exists
    final long largestTotalHits = topDocs.stream().mapToLong(docs -> docs.totalHits.value).max().getAsLong();
    return new TotalHits(largestTotalHits, relation);
}

private float getMaxScore(List<TopDocs> topDocs) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private float getMaxScore(List<TopDocs> topDocs) {
private float getMaxScore(final List<TopDocs> topDocs) {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, you mean this line should be called no matter what?
Yes, createTopDocsCollectorContext is a core method, and it may mutate searchContext. E.g. the query result object is set that way, so it's better to keep this sequence similar to what it is in core.

if (topDocs.size() == 0) {
return Float.NaN;
} else {
return topDocs.stream()
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
.map(scoreDoc -> scoreDoc.score)
.max(Float::compare)
.get();
}
}

/**
 * Extracts the doc value formats from the given sort definition.
 *
 * @param sortAndFormats sort definition, may be null
 * @return the {@code formats} of the sort definition, or null when no sort definition is given
 */
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
    if (sortAndFormats == null) {
        return null;
    }
    return sortAndFormats.formats;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

/**
 * Runs before each test and applies the cluster settings unless
 * {@link #isUpdateClusterSettings()} reports that they must not be updated.
 */
@Before
public void setupSettings() {
    if (!isUpdateClusterSettings()) {
        return;
    }
    updateClusterSettings();
}

protected void updateClusterSettings() {
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
// the default threshold for the native circuit breaker is 90, which may not be enough on a test runner machine
updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
Expand Down Expand Up @@ -514,4 +520,8 @@ protected void deleteModel(String modelId) {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}

/**
 * Tells {@code setupSettings()} whether the cluster settings should be updated before a test runs.
 * This implementation always answers yes; presumably subclasses override it to opt out (verify against subclasses).
 *
 * @return true, meaning the cluster settings are updated
 */
public boolean isUpdateClusterSettings() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@

package org.opensearch.neuralsearch.plugin;

import static org.mockito.Mockito.mock;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.test.OpenSearchTestCase;

public class NeuralSearchTests extends OpenSearchTestCase {
Expand All @@ -23,4 +31,21 @@ public void testQuerySpecs() {
assertTrue(querySpecs.stream().anyMatch(spec -> NeuralQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
}

/**
 * Verifies that the plugin exposes a {@link HybridQueryPhaseSearcher} as its query phase searcher.
 */
public void testQueryPhaseSearcher() {
    final Optional<QueryPhaseSearcher> searcher = new NeuralSearch().getQueryPhaseSearcher();

    assertNotNull(searcher);
    assertTrue(searcher.isPresent());
    assertTrue(searcher.get() instanceof HybridQueryPhaseSearcher);
}

/**
 * Verifies that the plugin registers a processor factory under the text embedding processor type.
 */
public void testProcessors() {
    final Processor.Parameters mockedParameters = mock(Processor.Parameters.class);
    final Map<String, Processor.Factory> factories = new NeuralSearch().getProcessors(mockedParameters);

    assertNotNull(factories);
    assertNotNull(factories.get(TextEmbeddingProcessor.TYPE));
}
}
Loading