Skip to content

Commit

Permalink
Adding search processor for score normalization and combination (#227)
Browse files Browse the repository at this point in the history
* Adding search processor for score normalization and combination

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 3, 2023
1 parent d87b63e commit fdec5fa
Show file tree
Hide file tree
Showing 22 changed files with 2,127 additions and 5 deletions.
6 changes: 5 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,12 @@ testClusters.integTest {
// Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due
// to ml-commons memory circuit breaker exception
jvmArgs("-Xms1g", "-Xmx1g")
// enable hybrid search for testing

// enable features for testing
// hybrid search
systemProperty('neural_search_hybrid_search_enabled', 'true')
// search pipelines
systemProperty('opensearch.experimental.feature.search_pipeline.enabled', 'true')
}

// Remote Integration Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.Optional;
import java.util.function.Supplier;

import lombok.extern.log4j.Log4j2;

import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -24,18 +26,27 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
Expand All @@ -45,7 +56,8 @@
/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {
@Log4j2
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
/**
* Gates the functionality of hybrid search
* Currently query phase searcher added with hybrid search will conflict with concurrent search in core.
Expand All @@ -54,6 +66,9 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
@VisibleForTesting
public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled";
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();;

@Override
public Collection<Object> createComponents(
Expand All @@ -70,6 +85,7 @@ public Collection<Object> createComponents(
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}

Expand All @@ -90,9 +106,21 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) {
log.info("Registering hybrid query phase searcher with feature flag [%]", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
return Optional.of(new HybridQueryPhaseSearcher());
}
log.info("Not registering hybrid query phase searcher because feature flag [%] is disabled", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
// we want feature be disabled by default due to risk of colliding and breaking concurrent search in core
return Optional.empty();
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(
Parameters parameters
) {
return Map.of(
NormalizationProcessor.TYPE,
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

/**
* Processor for score normalization and combination on post query search results. Updates query results with
* normalized and combined scores for next phase (typically it's FETCH)
*/
@Log4j2
@AllArgsConstructor
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
public static final String TYPE = "normalization-processor";

private final String tag;
private final String description;
private final ScoreNormalizationTechnique normalizationTechnique;
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
if (shouldRunProcessor(searchPhaseResult)) {
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
}

@Override
public SearchPhaseName getBeforePhase() {
return SearchPhaseName.QUERY;
}

@Override
public SearchPhaseName getAfterPhase() {
return SearchPhaseName.FETCH;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getTag() {
return tag;
}

@Override
public String getDescription() {
return description;
}

@Override
public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldRunProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer)) {
return true;
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
Optional<SearchPhaseResult> optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.filter(Objects::nonNull)
.findFirst();
return isNotHybridQuery(optionalSearchPhaseResult);
}

private boolean isNotHybridQuery(final Optional<SearchPhaseResult> optionalSearchPhaseResult) {
return optionalSearchPhaseResult.isEmpty()
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult())
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs())
|| !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;

import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.query.QuerySearchResult;

/**
* Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data
* and post-processing of final results
*/
@AllArgsConstructor
public class NormalizationProcessorWorkflow {

private final ScoreNormalizer scoreNormalizer;
private final ScoreCombiner scoreCombiner;

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
// pre-process data
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

// normalize
scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

// combine
scoreCombiner.combineScores(queryTopDocs, combinationTechnique);

// post-process data
updateOriginalQueryResults(querySearchResults, queryTopDocs);
}

/**
* Getting list of CompoundTopDocs from list of QuerySearchResult. Each CompoundTopDocs is for individual shard
* @param querySearchResults collection of QuerySearchResult for all shards
* @return collection of CompoundTopDocs, one object for each shard
*/
private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
.filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
.filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs)
.map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs)
.collect(Collectors.toList());
return queryTopDocs;
}

private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
for (int i = 0; i < querySearchResults.size(); i++) {
QuerySearchResult querySearchResult = querySearchResults.get(i);
if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) {
continue;
}
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import lombok.NoArgsConstructor;

/**
* Abstracts combination of scores based on arithmetic mean method
*/
@NoArgsConstructor
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
private static final Float ZERO_SCORE = 0.0f;

/**
* Arithmetic mean method for combining scores.
* cscore = (score1 + score2 +...+ scoreN)/N
*
* Zero (0.0) scores are excluded from number of scores N
*/
@Override
public float combine(final float[] scores) {
float combinedScore = 0.0f;
int count = 0;
for (float score : scores) {
if (score >= 0.0) {
combinedScore += score;
count++;
}
}
if (count == 0) {
return ZERO_SCORE;
}
return combinedScore / count;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.Map;
import java.util.Optional;

import org.opensearch.OpenSearchParseException;

/**
* Abstracts creation of exact score combination method based on technique name
*/
public class ScoreCombinationFactory {

public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique();

private final Map<String, ScoreCombinationTechnique> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
new ArithmeticMeanScoreCombinationTechnique()
);

/**
* Get score combination method by technique name
* @param technique name of technique
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique) {
return Optional.ofNullable(scoreCombinationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

public interface ScoreCombinationTechnique {
/**
* Defines combination function specific to this technique
* @param scores array of collected original scores
* @return combined score
*/
float combine(final float[] scores);
}
Loading

0 comments on commit fdec5fa

Please sign in to comment.