From 185050a32a150f7b538531f924d1e4d00498ba98 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Sat, 26 Aug 2023 00:24:17 +0200 Subject: [PATCH] [Backport 2.x] Added Score Normalization and Combination feature, manual backport (#263) * Added Score Normalization and Combination feature (#241) Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + build.gradle | 4 + .../neuralsearch/plugin/NeuralSearch.java | 63 +- .../processor/NormalizationProcessor.java | 122 +++ .../NormalizationProcessorWorkflow.java | 88 +++ ...ithmeticMeanScoreCombinationTechnique.java | 57 ++ ...eometricMeanScoreCombinationTechnique.java | 57 ++ ...HarmonicMeanScoreCombinationTechnique.java | 54 ++ .../combination/ScoreCombinationFactory.java | 52 ++ .../ScoreCombinationTechnique.java | 16 + .../combination/ScoreCombinationUtil.java | 80 ++ .../processor/combination/ScoreCombiner.java | 150 ++++ .../NormalizationProcessorFactory.java | 94 +++ .../L2ScoreNormalizationTechnique.java | 91 +++ .../MinMaxScoreNormalizationTechnique.java | 117 +++ .../ScoreNormalizationFactory.java | 34 + .../ScoreNormalizationTechnique.java | 22 + .../normalization/ScoreNormalizer.java | 29 + .../neuralsearch/query/HybridQuery.java | 164 ++++ .../query/HybridQueryBuilder.java | 299 ++++++++ .../neuralsearch/query/HybridQueryScorer.java | 150 ++++ .../neuralsearch/query/HybridQueryWeight.java | 117 +++ .../neuralsearch/search/CompoundTopDocs.java | 55 ++ .../search/HitsThresholdChecker.java | 44 ++ .../search/HybridTopScoreDocCollector.java | 149 ++++ .../query/HybridQueryPhaseSearcher.java | 148 ++++ .../settings/NeuralSearchSettings.java | 29 + .../opensearch/neuralsearch/TestUtils.java | 139 ++++ .../common/BaseNeuralSearchIT.java | 196 ++++- .../plugin/NeuralSearchTests.java | 76 ++ .../processor/NormalizationProcessorIT.java | 389 ++++++++++ .../NormalizationProcessorTests.java | 252 ++++++ .../NormalizationProcessorWorkflowTests.java | 69 ++ .../processor/ScoreCombinationIT.java | 412 ++++++++++ .../ScoreCombinationTechniqueTests.java | 86 +++ .../processor/ScoreNormalizationIT.java | 295 +++++++ .../ScoreNormalizationTechniqueTests.java | 231 ++++++ .../processor/TextEmbeddingProcessorIT.java | 14 + ...ticMeanScoreCombinationTechniqueTests.java | 82 ++ .../BaseScoreCombinationTechniqueTests.java | 96 +++ ...ricMeanScoreCombinationTechniqueTests.java | 98 +++ ...nicMeanScoreCombinationTechniqueTests.java | 80 ++ .../ScoreCombinationFactoryTests.java | 49 ++ .../NormalizationProcessorFactoryTests.java | 388 ++++++++++ .../L2ScoreNormalizationTechniqueTests.java | 225 ++++++ ...inMaxScoreNormalizationTechniqueTests.java | 178 +++++ .../ScoreNormalizationFactoryTests.java | 38 + .../query/HybridQueryBuilderTests.java | 719 ++++++++++++++++++ .../neuralsearch/query/HybridQueryIT.java | 410 ++++++++++ .../query/HybridQueryScorerTests.java | 209 +++++ .../neuralsearch/query/HybridQueryTests.java | 278 +++++++ .../query/HybridQueryWeightTests.java | 122 +++ .../neuralsearch/query/NeuralQueryIT.java | 54 +- .../query/OpenSearchQueryTestCase.java | 232 ++++++ .../search/CompoundTopDocsTests.java | 75 ++ .../search/HitsTresholdCheckerTests.java | 33 + .../HybridTopScoreDocCollectorTests.java | 352 +++++++++ .../query/HybridQueryPhaseSearcherTests.java | 418 ++++++++++ 58 files changed, 8556 insertions(+), 25 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java create mode 100644 src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java create mode 100644 src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f31d25198..1824d209f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.7...2.x) ### Features +* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/build.gradle b/build.gradle index 0961eb890..5147114e4 100644 --- a/build.gradle +++ b/build.gradle @@ -253,6 +253,10 @@ testClusters.integTest { // Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due // to ml-commons memory circuit breaker exception jvmArgs("-Xms1g", "-Xmx1g") + + // enable features for testing + // hybrid search + systemProperty('plugins.neural_search.hybrid_search_enabled', 'true') } // Remote Integration Tests diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 6bf85e3b8..b46d2bc6d 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -5,15 +5,23 @@ package org.opensearch.neuralsearch.plugin; +import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED; + +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; +import lombok.extern.log4j.Log4j2; + import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; @@ -21,25 +29,40 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; /** * Neural Search plugin class */ -public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin { - +@Log4j2 +public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin { private MLCommonsClientAccessor clientAccessor; + private NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();; @Override public Collection createComponents( @@ -56,12 +79,15 @@ public Collection createComponents( final Supplier repositoriesServiceSupplier ) { NeuralQueryBuilder.initialize(clientAccessor); + normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); return List.of(clientAccessor); } + @Override public List> getQueries() { - return Collections.singletonList( - new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent) + return Arrays.asList( + new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent), + new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent) ); } @@ -70,4 +96,33 @@ public Map getProcessors(Processor.Parameters paramet clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); } + + @Override + public Optional getQueryPhaseSearcher() { + if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey())) { + log.info("Registering hybrid query phase searcher with feature flag [{}]", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey()); + return Optional.of(new HybridQueryPhaseSearcher()); + } + log.info( + "Not registering hybrid query phase searcher because feature flag [{}] is disabled", + NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey() + ); + // we want feature be disabled by default due to risk of colliding and breaking concurrent search in core + return Optional.empty(); + } + + @Override + public Map> getSearchPhaseResultsProcessors( + Parameters parameters + ) { + return Map.of( + NormalizationProcessor.TYPE, + new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory) + ); + } + + @Override + public List> getSettings() { + return List.of(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java new file mode 100644 index 000000000..570d3b9e1 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.query.QuerySearchResult; + +/** + * Processor for score normalization and combination on post query search results. Updates query results with + * normalized and combined scores for next phase (typically it's FETCH) + */ +@Log4j2 +@AllArgsConstructor +public class NormalizationProcessor implements SearchPhaseResultsProcessor { + public static final String TYPE = "normalization-processor"; + + private final String tag; + private final String description; + private final ScoreNormalizationTechnique normalizationTechnique; + private final ScoreCombinationTechnique combinationTechnique; + private final NormalizationProcessorWorkflow normalizationWorkflow; + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + if (shouldSkipProcessor(searchPhaseResult)) { + log.debug("Query results are not compatible with normalization processor"); + return; + } + List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); + normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + + private boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer)) { + return true; + } + + QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult; + Optional optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray() + .asList() + .stream() + .filter(Objects::nonNull) + .findFirst(); + return isNotHybridQuery(optionalSearchPhaseResult); + } + + private boolean isNotHybridQuery(final Optional optionalSearchPhaseResult) { + return optionalSearchPhaseResult.isEmpty() + || Objects.isNull(optionalSearchPhaseResult.get().queryResult()) + || Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs()) + || !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs); + } + + private List getQueryPhaseSearchResults( + final SearchPhaseResults results + ) { + return results.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? null : result.queryResult()) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java new file mode 100644 index 000000000..fda095773 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.query.QuerySearchResult; + +/** + * Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data + * and post-processing of final results + */ +@AllArgsConstructor +@Log4j2 +public class NormalizationProcessorWorkflow { + + private final ScoreNormalizer scoreNormalizer; + private final ScoreCombiner scoreCombiner; + + /** + * Start execution of this workflow + * @param querySearchResults input data with QuerySearchResult from multiple shards + * @param normalizationTechnique technique for score normalization + * @param combinationTechnique technique for score combination + */ + public void execute( + final List querySearchResults, + final ScoreNormalizationTechnique normalizationTechnique, + final ScoreCombinationTechnique combinationTechnique + ) { + // pre-process data + log.debug("Pre-process query results"); + List queryTopDocs = getQueryTopDocs(querySearchResults); + + // normalize + log.debug("Do score normalization"); + scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + + // combine + log.debug("Do score combination"); + scoreCombiner.combineScores(queryTopDocs, combinationTechnique); + + // post-process data + log.debug("Post-process query results after score normalization and combination"); + updateOriginalQueryResults(querySearchResults, queryTopDocs); + } + + /** + * Getting list of CompoundTopDocs from list of QuerySearchResult. Each CompoundTopDocs is for individual shard + * @param querySearchResults collection of QuerySearchResult for all shards + * @return collection of CompoundTopDocs, one object for each shard + */ + private List getQueryTopDocs(final List querySearchResults) { + List queryTopDocs = querySearchResults.stream() + .filter(searchResult -> Objects.nonNull(searchResult.topDocs())) + .filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs) + .map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs) + .collect(Collectors.toList()); + return queryTopDocs; + } + + private void updateOriginalQueryResults(final List querySearchResults, final List queryTopDocs) { + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) { + continue; + } + CompoundTopDocs updatedTopDocs = queryTopDocs.get(i); + float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f; + TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore); + querySearchResult.topDocs(updatedTopDocsAndMaxScore, null); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java new file mode 100644 index 000000000..cfafeb3e5 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import lombok.ToString; + +/** + * Abstracts combination of scores based on arithmetic mean method + */ +@ToString(onlyExplicitlyIncluded = true) +public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "arithmetic_mean"; + public static final String PARAM_NAME_WEIGHTS = "weights"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); + private static final Float ZERO_SCORE = 0.0f; + private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; + + public ArithmeticMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); + } + + /** + * Arithmetic mean method for combining scores. + * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN) + * + * Zero (0.0) scores are excluded from number of scores N + */ + @Override + public float combine(final float[] scores) { + float combinedScore = 0.0f; + float sumOfWeights = 0; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { + float score = scores[indexOfSubQuery]; + if (score >= 0.0) { + float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); + score = score * weight; + combinedScore += score; + sumOfWeights += weight; + } + } + if (sumOfWeights == 0.0f) { + return ZERO_SCORE; + } + return combinedScore / sumOfWeights; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java new file mode 100644 index 000000000..4e7a8ca9e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import lombok.ToString; + +/** + * Abstracts combination of scores based on geometrical mean method + */ +@ToString(onlyExplicitlyIncluded = true) +public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "geometric_mean"; + public static final String PARAM_NAME_WEIGHTS = "weights"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); + private static final Float ZERO_SCORE = 0.0f; + private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; + + public GeometricMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); + } + + /** + * Weighted geometric mean method for combining scores. + * + * We use formula below to calculate mean. It's based on fact that logarithm of geometric mean is the + * weighted arithmetic mean of the logarithms of individual scores. + * + * geometric_mean = exp(sum(weight_1*ln(score_1) + .... + weight_n*ln(score_n))/sum(weight_1 + ... + weight_n)) + */ + @Override + public float combine(final float[] scores) { + float weightedLnSum = 0; + float sumOfWeights = 0; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { + float score = scores[indexOfSubQuery]; + if (score <= 0) { + // scores 0.0 need to be skipped, ln() of 0 is not defined + continue; + } + float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); + sumOfWeights += weight; + weightedLnSum += weight * Math.log(score); + } + return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java new file mode 100644 index 000000000..9f913b2ef --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import lombok.ToString; + +/** + * Abstracts combination of scores based on harmonic mean method + */ +@ToString(onlyExplicitlyIncluded = true) +public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "harmonic_mean"; + public static final String PARAM_NAME_WEIGHTS = "weights"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); + private static final Float ZERO_SCORE = 0.0f; + private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; + + public HarmonicMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); + } + + /** + * Weighted harmonic mean method for combining scores. + * score = sum(weight_1 + .... + weight_n)/sum(weight_1/score_1 + ... + weight_n/score_n) + * + * Zero (0.0) scores are excluded from number of scores N + */ + @Override + public float combine(final float[] scores) { + float sumOfWeights = 0; + float sumOfHarmonics = 0; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { + float score = scores[indexOfSubQuery]; + if (score <= 0) { + continue; + } + float weightOfSubQuery = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); + sumOfWeights += weightOfSubQuery; + sumOfHarmonics += weightOfSubQuery / score; + } + return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java new file mode 100644 index 000000000..f05d24823 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +/** + * Abstracts creation of exact score combination method based on technique name + */ +public class ScoreCombinationFactory { + private static final ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique( + Map.of(), + scoreCombinationUtil + ); + + private final Map, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of( + ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil), + HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil), + GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil) + ); + + /** + * Get score combination method by technique name + * @param technique name of technique + * @return instance of ScoreCombinationTechnique for technique name + */ + public ScoreCombinationTechnique createCombination(final String technique) { + return createCombination(technique, Map.of()); + } + + /** + * Get score combination method by technique name + * @param technique name of technique + * @param params parameters that combination technique may use + * @return instance of ScoreCombinationTechnique for technique name + */ + public ScoreCombinationTechnique createCombination(final String technique, final Map params) { + return Optional.ofNullable(scoreCombinationMethodsMap.get(technique)) + .orElseThrow(() -> new IllegalArgumentException("provided combination technique is not supported")) + .apply(params); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java new file mode 100644 index 000000000..6e0a5db65 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +public interface ScoreCombinationTechnique { + + /** + * Defines combination function specific to this technique + * @param scores array of collected original scores + * @return combined score + */ + float combine(final float[] scores); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java new file mode 100644 index 000000000..35e097f7f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Collection of utility methods for score combination technique classes + */ +class ScoreCombinationUtil { + private static final String PARAM_NAME_WEIGHTS = "weights"; + + /** + * Get collection of weights based on user provided config + * @param params map of named parameters and their values + * @return collection of weights + */ + public List getWeights(final Map params) { + if (Objects.isNull(params) || params.isEmpty()) { + return List.of(); + } + // get weights, we don't need to check for instance as it's done during validation + return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() + .map(Double::floatValue) + .collect(Collectors.toUnmodifiableList()); + } + + /** + * Validate config parameters for this technique + * @param actualParams map of parameters in form of name-value + * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique + */ + public void validateParams(final Map actualParams, final Set supportedParams) { + if (Objects.isNull(actualParams) || actualParams.isEmpty()) { + return; + } + // check if only supported params are passed + Optional optionalNotSupportedParam = actualParams.keySet() + .stream() + .filter(paramName -> !supportedParams.contains(paramName)) + .findFirst(); + if (optionalNotSupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "provided parameter for combination technique is not supported. supported parameters are [%s]", + supportedParams.stream().collect(Collectors.joining(",")) + ) + ); + } + + // check param types + if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) + ); + } + } + } + + /** + * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise + * @param weights collection of weights for sub-queries + * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query + * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default + */ + public float getWeightForSubQuery(final List weights, final int indexOfSubQuery) { + return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java new file mode 100644 index 000000000..67e776d77 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import lombok.extern.log4j.Log4j2; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts combination of scores in query search results. + */ +@Log4j2 +public class ScoreCombiner { + + private static final Float ZERO_SCORE = 0.0f; + + /** + * Performs score combination based on input combination technique. Mutates input object by updating combined scores + * Main steps we're doing for combination: + * - create map of normalized scores per doc id + * - using normalized scores create another map of combined scores per doc id + * - count max number of hits among sub-queries + * - sort documents by scores and take first "max number" of docs + * - update query search results with normalized scores + * Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score", + * other steps are same for all techniques. + * @param queryTopDocs query results that need to be normalized, mutated by method execution + * @param scoreCombinationTechnique exact combination method that should be applied + */ + public void combineScores(final List queryTopDocs, final ScoreCombinationTechnique scoreCombinationTechnique) { + // iterate over results from each shard. Every CompoundTopDocs object has results from + // multiple sub queries, doc ids may repeat for each sub query results + queryTopDocs.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs)); + } + + private void combineShardScores(final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) { + return; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + // - create map of normalized scores results returned from the single shard + Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery); + + // - create map of combined scores per doc id + Map combinedNormalizedScoresByDocId = combineScoresAndGetCombinedNormalizedScoresPerDocument( + normalizedScoresPerDoc, + scoreCombinationTechnique + ); + + // - sort documents by scores and take first "max number" of docs + // create a collection of doc ids that are sorted by their combined scores + List sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); + + // - update query search results with normalized scores + updateQueryTopDocsWithCombinedScores(compoundQueryTopDocs, topDocsPerSubQuery, combinedNormalizedScoresByDocId, sortedDocsIds); + } + + private List getSortedDocIds(final Map combinedNormalizedScoresByDocId) { + // we're merging docs with normalized and combined scores. we need to have only maxHits results + List sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet()); + sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a))); + return sortedDocsIds; + } + + private ScoreDoc[] getCombinedScoreDocs( + final CompoundTopDocs compoundQueryTopDocs, + final Map combinedNormalizedScoresByDocId, + final List sortedScores, + final int maxHits + ) { + ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits]; + + int shardId = compoundQueryTopDocs.scoreDocs[0].shardIndex; + for (int j = 0; j < maxHits && j < sortedScores.size(); j++) { + int docId = sortedScores.get(j); + finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId); + } + return finalScoreDocs; + } + + public Map getNormalizedScoresPerDocument(final List topDocsPerSubQuery) { + Map normalizedScoresPerDoc = new HashMap<>(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs topDocs = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, key -> { + float[] scores = new float[topDocsPerSubQuery.size()]; + // we initialize with -1.0, as after normalization it's possible that score is 0.0 + Arrays.fill(scores, -1.0f); + return scores; + }); + normalizedScoresPerDoc.get(scoreDoc.doc)[j] = scoreDoc.score; + } + } + return normalizedScoresPerDoc; + } + + private Map combineScoresAndGetCombinedNormalizedScoresPerDocument( + final Map normalizedScoresPerDocument, + final ScoreCombinationTechnique scoreCombinationTechnique + ) { + return normalizedScoresPerDocument.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + } + + private void updateQueryTopDocsWithCombinedScores( + final CompoundTopDocs compoundQueryTopDocs, + final List topDocsPerSubQuery, + final Map combinedNormalizedScoresByDocId, + final List sortedScores + ) { + // - count max number of hits among sub-queries + int maxHits = getMaxHits(topDocsPerSubQuery); + // - update query search results with normalized scores + compoundQueryTopDocs.scoreDocs = getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits); + compoundQueryTopDocs.totalHits = getTotalHits(topDocsPerSubQuery, maxHits); + } + + protected int getMaxHits(final List topDocsPerSubQuery) { + int maxHits = 0; + for (TopDocs topDocs : topDocsPerSubQuery) { + int hits = topDocs.scoreDocs.length; + maxHits = Math.max(maxHits, hits); + } + return maxHits; + } + + private TotalHits getTotalHits(final List topDocsPerSubQuery, int maxHits) { + TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO; + if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) { + totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + } + return new TotalHits(maxHits, totalHits); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java new file mode 100644 index 000000000..71412e736 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; + +import java.util.Map; +import java.util.Objects; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +/** + * Factory for query results normalization processor for search pipeline. Instantiates processor based on user provided input. + */ +@AllArgsConstructor +@Log4j2 +public class NormalizationProcessorFactory implements Processor.Factory { + public static final String NORMALIZATION_CLAUSE = "normalization"; + public static final String COMBINATION_CLAUSE = "combination"; + public static final String TECHNIQUE = "technique"; + public static final String PARAMETERS = "parameters"; + + private final NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private ScoreNormalizationFactory scoreNormalizationFactory; + private ScoreCombinationFactory scoreCombinationFactory; + + @Override + public SearchPhaseResultsProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) throws Exception { + Map normalizationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE); + ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD; + if (Objects.nonNull(normalizationClause)) { + String normalizationTechniqueName = readStringProperty( + NormalizationProcessor.TYPE, + tag, + normalizationClause, + TECHNIQUE, + MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME + ); + normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName); + } + + Map combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE); + + ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD; + if (Objects.nonNull(combinationClause)) { + String combinationTechnique = readStringProperty( + NormalizationProcessor.TYPE, + tag, + combinationClause, + TECHNIQUE, + ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME + ); + // check for optional combination params + Map combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS); + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams); + } + log.info( + "Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]", + NormalizationProcessor.TYPE, + normalizationTechnique, + scoreCombinationTechnique + ); + return new NormalizationProcessor( + tag, + description, + normalizationTechnique, + scoreCombinationTechnique, + normalizationProcessorWorkflow + ); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java new file mode 100644 index 000000000..0e55e7231 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import lombok.ToString; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts normalization of scores based on L2 method + */ +@ToString(onlyExplicitlyIncluded = true) +public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "l2"; + private static final float MIN_SCORE = 0.001f; + + /** + * L2 normalization method. + * n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2) + * Main algorithm steps: + * - calculate sum of squares of all scores + * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query + */ + @Override + public void normalize(final List queryTopDocs) { + // get l2 norms for each sub-query + List normsPerSubquery = getL2Norm(queryTopDocs); + + // do normalization using actual score and l2 norm + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j)); + } + } + } + } + + private List getL2Norm(final List queryTopDocs) { + // find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries, + // or it has results for all the sub-queries. In edge case of shard having results only for one sub-query, there will be TopDocs for + // rest of sub-queries with zero total hits + int numOfSubqueries = queryTopDocs.stream() + .filter(Objects::nonNull) + .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0) + .findAny() + .get() + .getCompoundTopDocs() + .size(); + float[] l2Norms = new float[numOfSubqueries]; + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + int bound = topDocsPerSubQuery.size(); + for (int index = 0; index < bound; index++) { + for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) { + l2Norms[index] += scoreDocs.score * scoreDocs.score; + } + } + } + for (int index = 0; index < l2Norms.length; index++) { + l2Norms[index] = (float) Math.sqrt(l2Norms[index]); + } + List l2NormList = new ArrayList<>(); + for (int index = 0; index < numOfSubqueries; index++) { + l2NormList.add(l2Norms[index]); + } + return l2NormList; + } + + private float normalizeSingleScore(final float score, final float l2Norm) { + return l2Norm == 0 ? MIN_SCORE : score / l2Norm; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java new file mode 100644 index 000000000..e32dbb033 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import lombok.ToString; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +import com.google.common.primitives.Floats; + +/** + * Abstracts normalization of scores based on min-max method + */ +@ToString(onlyExplicitlyIncluded = true) +public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "min_max"; + private static final float MIN_SCORE = 0.001f; + private static final float SINGLE_RESULT_SCORE = 1.0f; + + /** + * Min-max normalization method. + * nscore = (score - min_score)/(max_score - min_score) + * Main algorithm steps: + * - calculate min and max scores for each sub query + * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query + */ + @Override + public void normalize(final List queryTopDocs) { + int numOfSubqueries = queryTopDocs.stream() + .filter(Objects::nonNull) + .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0) + .findAny() + .get() + .getCompoundTopDocs() + .size(); + // get min scores for each sub query + float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); + + // get max scores for each sub query + float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); + + // do normalization using actual score and min and max scores for corresponding sub query + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + scoreDoc.score = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]); + } + } + } + } + + private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { + float[] maxScores = new float[numOfSubqueries]; + Arrays.fill(maxScores, Float.MIN_VALUE); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + maxScores[j] = Math.max( + maxScores[j], + Arrays.stream(topDocsPerSubQuery.get(j).scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .orElse(Float.MIN_VALUE) + ); + } + } + return maxScores; + } + + private float[] getMinScores(final List queryTopDocs, final int numOfScores) { + float[] minScores = new float[numOfScores]; + Arrays.fill(minScores, Float.MAX_VALUE); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + minScores[j] = Math.min( + minScores[j], + Arrays.stream(topDocsPerSubQuery.get(j).scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .min(Float::compare) + .orElse(Float.MAX_VALUE) + ); + } + } + return minScores; + } + + private float normalizeSingleScore(final float score, final float minScore, final float maxScore) { + // edge case when there is only one score and min and max scores are same + if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) { + return SINGLE_RESULT_SCORE; + } + float normalizedScore = (score - minScore) / (maxScore - minScore); + return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java new file mode 100644 index 000000000..667c237c7 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.Map; +import java.util.Optional; + +/** + * Abstracts creation of exact score normalization method based on technique name + */ +public class ScoreNormalizationFactory { + + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); + + private final Map scoreNormalizationMethodsMap = Map.of( + MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, + new MinMaxScoreNormalizationTechnique(), + L2ScoreNormalizationTechnique.TECHNIQUE_NAME, + new L2ScoreNormalizationTechnique() + ); + + /** + * Get score normalization method by technique name + * @param technique name of technique + * @return instance of ScoreNormalizationMethod for technique name + */ + public ScoreNormalizationTechnique createNormalization(final String technique) { + return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique)) + .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java new file mode 100644 index 000000000..fdaeb85d8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.List; + +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts normalization of scores in query search results. + */ +public interface ScoreNormalizationTechnique { + + /** + * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. + * @param queryTopDocs original query results from multiple shards and multiple sub-queries + */ + void normalize(final List queryTopDocs); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java new file mode 100644 index 000000000..5b8b7b1ca --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.List; +import java.util.Objects; + +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +public class ScoreNormalizer { + + /** + * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. + * @param queryTopDocs original query results from multiple shards and multiple sub-queries + * @param scoreNormalizationTechnique exact normalization technique that should be applied + */ + public void normalizeScores(final List queryTopDocs, final ScoreNormalizationTechnique scoreNormalizationTechnique) { + if (canQueryResultsBeNormalized(queryTopDocs)) { + scoreNormalizationTechnique.normalize(queryTopDocs); + } + } + + private boolean canQueryResultsBeNormalized(final List queryTopDocs) { + return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getCompoundTopDocs().size() > 0); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java new file mode 100644 index 000000000..2c79c56e5 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; + +/** + * Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual + * scores for each sub-query. + */ +public final class HybridQuery extends Query implements Iterable { + + private final List subQueries; + + public HybridQuery(Collection subQueries) { + Objects.requireNonNull(subQueries, "collection of queries must not be null"); + if (subQueries.isEmpty()) { + throw new IllegalArgumentException("collection of queries must not be empty"); + } + this.subQueries = new ArrayList<>(subQueries); + } + + /** + * Returns an iterator over sub-queries that are parts of this hybrid query + * @return iterator + */ + @Override + public Iterator iterator() { + return getSubQueries().iterator(); + } + + /** + * Prints a query to a string, with field assumed to be the default field and omitted. + * @param field default field + * @return string representation of hybrid query + */ + @Override + public String toString(String field) { + StringBuilder buffer = new StringBuilder(); + buffer.append("("); + Iterator it = subQueries.iterator(); + for (int i = 0; it.hasNext(); i++) { + Query subquery = it.next(); + if (subquery instanceof BooleanQuery) { // wrap sub-boolean in parents + buffer.append("("); + buffer.append(subquery.toString(field)); + buffer.append(")"); + } else { + buffer.append(subquery.toString(field)); + } + if (i != subQueries.size() - 1) { + buffer.append(" | "); + } + } + buffer.append(")"); + return buffer.toString(); + } + + /** + * Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary, + * until the rewritten query is the same as the original query. + * @param reader + * @return + * @throws IOException + */ + @Override + public Query rewrite(IndexReader reader) throws IOException { + if (subQueries.isEmpty()) { + return new MatchNoDocsQuery("empty HybridQuery"); + } + + boolean actuallyRewritten = false; + List rewrittenSubQueries = new ArrayList<>(); + for (Query subQuery : subQueries) { + Query rewrittenSub = subQuery.rewrite(reader); + /* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls + queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses + perform better. For hybrid query we need to track progress of re-write for all sub-queries */ + actuallyRewritten |= rewrittenSub != subQuery; + rewrittenSubQueries.add(rewrittenSub); + } + + if (actuallyRewritten) { + return new HybridQuery(rewrittenSubQueries); + } + + return super.rewrite(reader); + } + + /** + * Recurse through the query tree, visiting all child queries and execute provided visitor. Part of multiple + * standard workflows, e.g. IndexSearcher.rewrite + * @param queryVisitor a QueryVisitor to be called by each query in the tree + */ + @Override + public void visit(QueryVisitor queryVisitor) { + QueryVisitor v = queryVisitor.getSubVisitor(BooleanClause.Occur.SHOULD, this); + for (Query q : subQueries) { + q.visit(v); + } + } + + /** + * Override and implement query instance equivalence properly in a subclass. This is required so that QueryCache works properly. + * @param other query object that when compare with this query object + * @return + */ + @Override + public boolean equals(Object other) { + return sameClassAs(other) && equalsTo(getClass().cast(other)); + } + + private boolean equalsTo(HybridQuery other) { + return Objects.equals(subQueries, other.subQueries); + } + + /** + * Override and implement query hash code properly in a subclass. This is required so that QueryCache works properly. + * @return hash code of this object + */ + @Override + public int hashCode() { + int h = classHash(); + h = 31 * h + Objects.hashCode(subQueries); + return h; + } + + public Collection getSubQueries() { + return Collections.unmodifiableCollection(subQueries); + } + + /** + * Create the Weight used to score this query + * + * @param searcher + * @param scoreMode How the produced scorers will be consumed. + * @param boost The boost that is propagated by the parent queries. + * @return + * @throws IOException + */ + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new HybridQueryWeight(this, searcher, scoreMode, boost); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java new file mode 100644 index 000000000..19540550d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -0,0 +1,299 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.stream.Collectors; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.lucene.search.Query; +import org.opensearch.common.lucene.search.Queries; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.QueryShardException; +import org.opensearch.index.query.Rewriteable; + +/** + * Class abstract creation of a Query type "hybrid". Hybrid query will allow execution of multiple sub-queries and + * collects score for each of those sub-query. + */ +@Log4j2 +@Getter +@Setter +@Accessors(chain = true, fluent = true) +@NoArgsConstructor +public final class HybridQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "hybrid"; + + private static final ParseField QUERIES_FIELD = new ParseField("queries"); + + private final List queries = new ArrayList<>(); + + private String fieldName; + + private static final int MAX_NUMBER_OF_SUB_QUERIES = 5; + + public HybridQueryBuilder(StreamInput in) throws IOException { + super(in); + queries.addAll(readQueries(in)); + } + + /** + * Serialize this query object into input stream + * @param out stream that we'll be used for serialization + * @throws IOException + */ + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + writeQueries(out, queries); + } + + /** + * Add one sub-query + * @param queryBuilder + * @return + */ + public HybridQueryBuilder add(QueryBuilder queryBuilder) { + if (queryBuilder == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "inner %s query clause cannot be null", NAME)); + } + queries.add(queryBuilder); + return this; + } + + /** + * Create builder object with a content of this hybrid query + * @param builder + * @param params + * @throws IOException + */ + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.startArray(QUERIES_FIELD.getPreferredName()); + for (QueryBuilder queryBuilder : queries) { + queryBuilder.toXContent(builder, params); + } + builder.endArray(); + printBoostAndQueryName(builder); + builder.endObject(); + } + + /** + * Create query object for current hybrid query using shard context + * @param queryShardContext context object that used to create hybrid query + * @return hybrid query object + * @throws IOException + */ + @Override + protected Query doToQuery(QueryShardContext queryShardContext) throws IOException { + Collection queryCollection = toQueries(queries, queryShardContext); + if (queryCollection.isEmpty()) { + return Queries.newMatchNoDocsQuery(String.format(Locale.ROOT, "no clauses for %s query", NAME)); + } + return new HybridQuery(queryCollection); + } + + /** + * Creates HybridQueryBuilder from xContent. + * Example of a json for Hybrid Query: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 10 + * } + * } + * }, + * { + * "term": { + * "text": "keyword" + * } + * } + * ] + * } + * } + * } + * + * @param parser parser that has been initialized with the query content + * @return new instance of HybridQueryBuilder + * @throws IOException + */ + public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException { + float boost = AbstractQueryBuilder.DEFAULT_BOOST; + + final List queries = new ArrayList<>(); + String queryName = null; + + String currentFieldName = null; + XContentParser.Token token; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + queries.add(parseInnerQueryBuilder(parser)); + } else { + log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName)); + throw new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "Field is not supported by [%s] query", NAME) + ); + } + } else if (token == XContentParser.Token.START_ARRAY) { + if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + while (token != XContentParser.Token.END_ARRAY) { + if (queries.size() == MAX_NUMBER_OF_SUB_QUERIES) { + throw new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "Number of sub-queries exceeds maximum supported by [%s] query", NAME) + ); + } + queries.add(parseInnerQueryBuilder(parser)); + token = parser.nextToken(); + } + } else { + log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName)); + throw new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "Field is not supported by [%s] query", NAME) + ); + } + } else { + if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + boost = parser.floatValue(); + // regular boost functionality is not supported, user should use score normalization methods to manipulate with scores + if (boost != DEFAULT_BOOST) { + log.error("[{}] query does not support provided value {} for [{}]", NAME, boost, BOOST_FIELD); + throw new ParsingException(parser.getTokenLocation(), "[{}] query does not support [{}]", NAME, BOOST_FIELD); + } + } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + queryName = parser.text(); + } else { + log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName)); + throw new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "Field is not supported by [%s] query", NAME) + ); + } + } + } + + if (queries.isEmpty()) { + throw new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "[%s] requires 'queries' field with at least one clause", NAME) + ); + } + + HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder(); + compoundQueryBuilder.queryName(queryName); + compoundQueryBuilder.boost(boost); + for (QueryBuilder query : queries) { + compoundQueryBuilder.add(query); + } + return compoundQueryBuilder; + } + + protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException { + HybridQueryBuilder newBuilder = new HybridQueryBuilder(); + boolean changed = false; + for (QueryBuilder query : queries) { + QueryBuilder result = query.rewrite(queryShardContext); + if (result != query) { + changed = true; + } + newBuilder.add(result); + } + if (changed) { + newBuilder.queryName(queryName); + newBuilder.boost(boost); + return newBuilder; + } else { + return this; + } + } + + /** + * Indicates whether some other QueryBuilder object of the same type is "equal to" this one. + * @param obj + * @return true if objects are equal + */ + @Override + protected boolean doEquals(HybridQueryBuilder obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(fieldName, obj.fieldName); + equalsBuilder.append(queries, obj.queries); + return equalsBuilder.isEquals(); + } + + /** + * Create hash code for current hybrid query builder object + * @return hash code + */ + @Override + protected int doHashCode() { + return Objects.hash(queries); + } + + /** + * Returns the name of the writeable object + * @return + */ + @Override + public String getWriteableName() { + return NAME; + } + + private List readQueries(StreamInput in) throws IOException { + return in.readNamedWriteableList(QueryBuilder.class); + } + + private void writeQueries(StreamOutput out, List queries) throws IOException { + out.writeNamedWriteableList(queries); + } + + private Collection toQueries(Collection queryBuilders, QueryShardContext context) throws QueryShardException { + List queries = queryBuilders.stream().map(qb -> { + try { + return Rewriteable.rewrite(qb, context).toQuery(context); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).filter(Objects::nonNull).collect(Collectors.toList()); + return queries; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java new file mode 100644 index 000000000..42f0a56e6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import lombok.Getter; + +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DisjunctionDISIApproximation; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +/** + * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing + * order of doc id, this class fills up array of scores per sub-query for each doc id. Order in array of scores + * corresponds to order of sub-queries in an input Hybrid query. + */ +public final class HybridQueryScorer extends Scorer { + + // score for each of sub-query in this hybrid query + @Getter + private final List subScorers; + + private final DisiPriorityQueue subScorersPQ; + + private final float[] subScores; + + private final Map queryToIndex; + + public HybridQueryScorer(Weight weight, List subScorers) throws IOException { + super(weight); + this.subScorers = Collections.unmodifiableList(subScorers); + subScores = new float[subScorers.size()]; + this.queryToIndex = mapQueryToIndex(); + this.subScorersPQ = initializeSubScorersPQ(); + } + + /** + * Returns the score of the current document matching the query. Score is a sum of all scores from sub-query scorers. + * @return combined total score of all sub-scores + * @throws IOException + */ + @Override + public float score() throws IOException { + DisiWrapper topList = subScorersPQ.topList(); + float totalScore = 0.0f; + for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { + // check if this doc has match in the subQuery. If not, add score as 0.0 and continue + if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { + continue; + } + totalScore += disiWrapper.scorer.score(); + } + return totalScore; + } + + /** + * Return a DocIdSetIterator over matching documents. + * @return DocIdSetIterator object + */ + @Override + public DocIdSetIterator iterator() { + return new DisjunctionDISIApproximation(this.subScorersPQ); + } + + /** + * Return the maximum score that documents between the last target that this iterator was shallow-advanced to included and upTo included. + * @param upTo upper limit for document id + * @return max score + * @throws IOException + */ + @Override + public float getMaxScore(int upTo) throws IOException { + return subScorers.stream().filter(scorer -> scorer.docID() <= upTo).map(scorer -> { + try { + return scorer.getMaxScore(upTo); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).max(Float::compare).orElse(0.0f); + } + + /** + * Returns the doc ID that is currently being scored. + * @return document id + */ + @Override + public int docID() { + return subScorersPQ.top().doc; + } + + /** + * Return array of scores per sub-query for doc id that is defined by current iterator position + * @return + * @throws IOException + */ + public float[] hybridScores() throws IOException { + float[] scores = new float[subScores.length]; + DisiWrapper topList = subScorersPQ.topList(); + for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { + // check if this doc has match in the subQuery. If not, add score as 0.0 and continue + if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { + continue; + } + float subScore = disiWrapper.scorer.score(); + scores[queryToIndex.get(disiWrapper.scorer.getWeight().getQuery())] = subScore; + } + return scores; + } + + private Map mapQueryToIndex() { + Map queryToIndex = new HashMap<>(); + int idx = 0; + for (Scorer scorer : subScorers) { + if (scorer == null) { + idx++; + continue; + } + queryToIndex.put(scorer.getWeight().getQuery(), idx); + idx++; + } + return queryToIndex; + } + + private DisiPriorityQueue initializeSubScorersPQ() { + Objects.requireNonNull(queryToIndex, "should not be null"); + Objects.requireNonNull(subScorers, "should not be null"); + DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(queryToIndex.size()); + for (Scorer scorer : subScorers) { + if (scorer == null) { + continue; + } + final DisiWrapper w = new DisiWrapper(scorer); + subScorersPQ.add(w); + } + return subScorersPQ; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java new file mode 100644 index 000000000..605892ea0 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Matches; +import org.apache.lucene.search.MatchesUtils; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +/** + * Calculates query weights and build query scorers for hybrid query. + */ +public final class HybridQueryWeight extends Weight { + + private final HybridQuery queries; + // The Weights for our subqueries, in 1-1 correspondence + private final List weights; + + private final ScoreMode scoreMode; + + /** + * Construct the Weight for this Query searched by searcher. Recursively construct subquery weights. + */ + public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + super(hybridQuery); + this.queries = hybridQuery; + weights = hybridQuery.getSubQueries().stream().map(q -> { + try { + return searcher.createWeight(q, scoreMode, boost); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).collect(Collectors.toList()); + this.scoreMode = scoreMode; + } + + /** + * Returns Matches for a specific document, or null if the document does not match the parent query + * + * @param context the reader's context to create the {@link Matches} for + * @param doc the document's id relative to the given context's reader + * @return + * @throws IOException + */ + @Override + public Matches matches(LeafReaderContext context, int doc) throws IOException { + List mis = weights.stream().map(weight -> { + try { + return weight.matches(context, doc); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).filter(Objects::nonNull).collect(Collectors.toList()); + return MatchesUtils.fromSubMatches(mis); + } + + /** + * Create the scorer used to score our associated Query + * + * @param context the {@link LeafReaderContext} for which to return the + * {@link Scorer}. + * @return scorer of hybrid query that contains scorers of each sub-query, null if there are no matches in any sub-query + * @throws IOException + */ + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + List scorers = weights.stream().map(w -> { + try { + return w.scorer(context); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).collect(Collectors.toList()); + // if there are no matches in any of the scorers (sub-queries) we need to return + // scorer as null to avoid problems with disi result iterators + if (scorers.stream().allMatch(Objects::isNull)) { + return null; + } + return new HybridQueryScorer(this, scorers); + } + + /** + * Check if weight object can be cached + * + * @param ctx + * @return true if the object can be cached against a given leaf + */ + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return weights.stream().allMatch(w -> w.isCacheable(ctx)); + } + + /** + * Explain is not supported for hybrid query + * + * @param context the readers context to create the {@link Explanation} for. + * @param doc the document's id relative to the given context's reader + * @return + * @throws IOException + */ + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + throw new UnsupportedOperationException("Explain is not supported"); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java new file mode 100644 index 000000000..fbc820d8b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search; + +import java.util.Arrays; +import java.util.List; + +import lombok.Getter; +import lombok.ToString; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; + +/** + * Class stores collection of TodDocs for each sub query from hybrid query + */ +@ToString(includeFieldNames = true) +public class CompoundTopDocs extends TopDocs { + + @Getter + private List compoundTopDocs; + + public CompoundTopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) { + super(totalHits, scoreDocs); + } + + public CompoundTopDocs(TotalHits totalHits, List docs) { + // we pass clone of score docs from the sub-query that has most hits + super(totalHits, cloneLargestScoreDocs(docs)); + this.compoundTopDocs = docs; + } + + private static ScoreDoc[] cloneLargestScoreDocs(List docs) { + if (docs == null) { + return null; + } + ScoreDoc[] maxScoreDocs = new ScoreDoc[0]; + int maxLength = -1; + for (TopDocs topDoc : docs) { + if (topDoc == null || topDoc.scoreDocs == null) { + continue; + } + if (topDoc.scoreDocs.length > maxLength) { + maxLength = topDoc.scoreDocs.length; + maxScoreDocs = topDoc.scoreDocs; + } + } + // do deep copy + return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java new file mode 100644 index 000000000..4e5c77a5b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search; + +import java.util.Locale; + +import lombok.Getter; + +import org.apache.lucene.search.ScoreMode; + +/** + * Abstracts algorithm that allows early termination for the search flow if number of hits reached + * certain treshold + */ +public class HitsThresholdChecker { + private int hitCount; + @Getter + private final int totalHitsThreshold; + + public HitsThresholdChecker(int totalHitsThreshold) { + if (totalHitsThreshold < 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be >= 0, got %d", totalHitsThreshold)); + } + if (totalHitsThreshold == Integer.MAX_VALUE) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be less than max integer value")); + } + this.totalHitsThreshold = totalHitsThreshold; + } + + protected void incrementHitCount() { + ++hitCount; + } + + protected boolean isThresholdReached() { + return hitCount >= getTotalHitsThreshold(); + } + + protected ScoreMode scoreMode() { + return ScoreMode.TOP_SCORES; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java new file mode 100644 index 000000000..3fa413826 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.PriorityQueue; +import org.opensearch.neuralsearch.query.HybridQueryScorer; + +/** + * Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results + */ +@Log4j2 +public class HybridTopScoreDocCollector implements Collector { + private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + private int docBase; + private final HitsThresholdChecker hitsThresholdChecker; + private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO; + private int[] totalHits; + private final int numOfHits; + @Getter + private PriorityQueue[] compoundScores; + + public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) { + numOfHits = numHits; + this.hitsThresholdChecker = hitsThresholdChecker; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + docBase = context.docBase; + + return new TopScoreDocCollector.ScorerLeafCollector() { + HybridQueryScorer compoundQueryScorer; + + @Override + public void setScorer(Scorable scorer) throws IOException { + super.setScorer(scorer); + compoundQueryScorer = (HybridQueryScorer) scorer; + } + + @Override + public void collect(int doc) throws IOException { + float[] subScoresByQuery = compoundQueryScorer.hybridScores(); + // iterate over results for each query + if (compoundScores == null) { + compoundScores = new PriorityQueue[subScoresByQuery.length]; + for (int i = 0; i < subScoresByQuery.length; i++) { + compoundScores[i] = new HitQueue(numOfHits, true); + } + totalHits = new int[subScoresByQuery.length]; + } + for (int i = 0; i < subScoresByQuery.length; i++) { + float score = subScoresByQuery[i]; + // if score is 0.0 there is no hits for that sub-query + if (score == 0) { + continue; + } + totalHits[i]++; + PriorityQueue pq = compoundScores[i]; + ScoreDoc topDoc = pq.top(); + topDoc.doc = doc + docBase; + topDoc.score = score; + pq.updateTop(); + } + } + }; + } + + @Override + public ScoreMode scoreMode() { + return hitsThresholdChecker.scoreMode(); + } + + /** + * Get resulting collection of TopDocs for hybrid query after we ran search for each of its sub query + * @return + */ + public List topDocs() { + if (compoundScores == null) { + return new ArrayList<>(); + } + final List topDocs = IntStream.range(0, compoundScores.length) + .mapToObj(i -> topDocsPerQuery(0, Math.min(totalHits[i], compoundScores[i].size()), compoundScores[i], totalHits[i])) + .collect(Collectors.toList()); + return topDocs; + } + + private TopDocs topDocsPerQuery(int start, int howMany, PriorityQueue pq, int totalHits) { + if (howMany < 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Number of hits requested must be greater than 0 but value was %d", howMany) + ); + } + + if (start < 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Expected value of starting position is between 0 and %d, got %d", howMany, start) + ); + } + + if (start >= howMany || howMany == 0) { + return EMPTY_TOPDOCS; + } + + int size = howMany - start; + ScoreDoc[] results = new ScoreDoc[size]; + // pq's pop() returns the 'least' element in the queue, therefore need + // to discard the first ones, until we reach the requested range. + for (int i = pq.size() - start - size; i > 0; i--) { + pq.pop(); + } + + // Get the requested results from pq. + populateResults(results, size, pq); + + return new TopDocs(new TotalHits(totalHits, totalHitsRelation), results); + } + + protected void populateResults(ScoreDoc[] results, int howMany, PriorityQueue pq) { + for (int i = howMany - 1; i >= 0 && pq.size() > 0; i--) { + // adding to array if index is within [0..array_length - 1] + if (i < results.length) { + results[i] = pq.pop(); + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java new file mode 100644 index 000000000..81b6b7ebd --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search.query; + +import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; + +import lombok.extern.log4j.Log4j2; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QueryPhase; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.TopDocsCollectorContext; +import org.opensearch.search.rescore.RescoreContext; +import org.opensearch.search.sort.SortAndFormats; + +import com.google.common.annotations.VisibleForTesting; + +/** + * Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the + * upstream standard implementation of searcher is called. + */ +@Log4j2 +public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher { + + public boolean searchWith( + final SearchContext searchContext, + final ContextIndexSearcher searcher, + final Query query, + final LinkedList collectors, + final boolean hasFilterCollector, + final boolean hasTimeout + ) throws IOException { + if (query instanceof HybridQuery) { + return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } + return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } + + @VisibleForTesting + protected boolean searchWithCollector( + final SearchContext searchContext, + final ContextIndexSearcher searcher, + final Query query, + final LinkedList collectors, + final boolean hasFilterCollector, + final boolean hasTimeout + ) throws IOException { + log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId()); + + final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); + collectors.addFirst(topDocsFactory); + if (searchContext.size() == 0) { + final TotalHitCountCollector collector = new TotalHitCountCollector(); + searcher.search(query, collector); + return false; + } + final IndexReader reader = searchContext.searcher().getIndexReader(); + int totalNumDocs = Math.max(0, reader.numDocs()); + int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); + final boolean shouldRescore = !searchContext.rescore().isEmpty(); + if (shouldRescore) { + for (RescoreContext rescoreContext : searchContext.rescore()) { + numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); + } + } + + final QuerySearchResult queryResult = searchContext.queryResult(); + + final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( + numDocs, + new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) + ); + + searcher.search(query, collector); + + if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { + queryResult.terminatedEarly(false); + } + + setTopDocsInQueryResult(queryResult, collector, searchContext); + + return shouldRescore; + } + + private void setTopDocsInQueryResult( + final QuerySearchResult queryResult, + final HybridTopScoreDocCollector collector, + final SearchContext searchContext + ) { + final List topDocs = collector.topDocs(); + final float maxScore = getMaxScore(topDocs); + final TopDocs newTopDocs = new CompoundTopDocs(getTotalHits(searchContext, topDocs), topDocs); + final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); + queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); + } + + private TotalHits getTotalHits(final SearchContext searchContext, final List topDocs) { + int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); + final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + if (topDocs == null || topDocs.size() == 0) { + return new TotalHits(0, relation); + } + long maxTotalHits = topDocs.get(0).totalHits.value; + for (TopDocs topDoc : topDocs) { + maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value); + } + return new TotalHits(maxTotalHits, relation); + } + + private float getMaxScore(final List topDocs) { + if (topDocs.size() == 0) { + return Float.NaN; + } else { + return topDocs.stream() + .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .get(); + } + } + + private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { + return sortAndFormats == null ? null : sortAndFormats.formats; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java new file mode 100644 index 000000000..995f0c0fa --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.settings; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +import org.opensearch.common.settings.Setting; + +/** + * Class defines settings specific to neural-search plugin + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class NeuralSearchSettings { + + /** + * Gates the functionality of hybrid search + * Currently query phase searcher added with hybrid search will conflict with concurrent search in core. + * Once that problem is resolved this feature flag can be removed. + */ + public static final Setting NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = Setting.boolSetting( + "plugins.neural_search.hybrid_search_enabled", + false, + Setting.Property.NodeScope + ); +} diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 257545132..ff221bf20 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -5,16 +5,29 @@ package org.opensearch.neuralsearch; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.opensearch.test.OpenSearchTestCase.randomFloat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; +import org.apache.commons.lang3.Range; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.query.QuerySearchResult; public class TestUtils { + private final static String RELATION_EQUAL_TO = "eq"; + /** * Convert an xContentBuilder to a map * @param xContentBuilder to produce map from @@ -51,4 +64,130 @@ public static float[] createRandomVector(int dimension) { } return vector; } + + /** + * Assert results of hyrdid query after score normalization and combination + * @param querySearchResults collection of query search results after they processed by normalization processor + */ + public static void assertQueryResultScores(List querySearchResults) { + assertNotNull(querySearchResults); + float maxScore = querySearchResults.stream() + .map(searchResult -> searchResult.topDocs().maxScore) + .max(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(1.0f, maxScore, 0.0f); + float totalMaxScore = querySearchResults.stream() + .map(searchResult -> searchResult.getMaxScore()) + .max(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(1.0f, totalMaxScore, 0.0f); + float maxScoreScoreFromScoreDocs = querySearchResults.stream() + .map( + searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .orElse(Float.MAX_VALUE) + ) + .max(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(1.0f, maxScoreScoreFromScoreDocs, 0.0f); + float minScoreScoreFromScoreDocs = querySearchResults.stream() + .map( + searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .min(Float::compare) + .orElse(Float.MAX_VALUE) + ) + .min(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(0.001f, minScoreScoreFromScoreDocs, 0.0f); + } + + /** + * Assert results of hybrid query after score normalization and combination + * @param searchResponseWithWeightsAsMap collection of query search results after they processed by normalization processor + * @param expectedMaxScore expected maximum score + * @param expectedMaxMinusOneScore second highest expected score + * @param expectedMinScore expected minimal score + */ + public static void assertWeightedScores( + Map searchResponseWithWeightsAsMap, + double expectedMaxScore, + double expectedMaxMinusOneScore, + double expectedMinScore + ) { + assertNotNull(searchResponseWithWeightsAsMap); + Map totalWeights = getTotalHits(searchResponseWithWeightsAsMap); + assertNotNull(totalWeights.get("value")); + assertEquals(4, totalWeights.get("value")); + assertNotNull(totalWeights.get("relation")); + assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation")); + assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent()); + assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f); + + List scoresWeights = new ArrayList<>(); + for (Map oneHit : getNestedHits(searchResponseWithWeightsAsMap)) { + scoresWeights.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1))); + // verify the scores are normalized with inclusion of weights + assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001); + assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001); + assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001); + } + + /** + * Assert results of hybrid query after score normalization and combination + * @param searchResponseAsMap collection of query search results after they processed by normalization processor + * @param totalExpectedDocQty expected total document quantity + * @param minMaxScoreRange range of scores from min to max inclusive + */ + public static void assertHybridSearchResults( + Map searchResponseAsMap, + int totalExpectedDocQty, + float[] minMaxScoreRange + ) { + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get())); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range + assertTrue( + Range.between(minMaxScoreRange[0], minMaxScoreRange[1]) + .contains(scores.stream().map(Double::floatValue).max(Double::compare).get()) + ); + + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } + + private static List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + private static Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map) hitsMap.get("total"); + } + + private static Optional getMaxScore(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index ca435298d..3f49d9351 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -13,10 +13,13 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -57,11 +60,21 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5; private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; + private static final String DEFAULT_USER_AGENT = "Kibana"; + protected static final String DEFAULT_NORMALIZATION_METHOD = "min_max"; + protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; + protected static final String PARAM_NAME_WEIGHTS = "weights"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @Before public void setupSettings() { + if (isUpdateClusterSettings()) { + updateClusterSettings(); + } + } + + protected void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); // default threshold for native circuit breaker is 90, it may be not enough on test runner machine updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100); @@ -279,6 +292,27 @@ protected Map search(String index, QueryBuilder queryBuilder, in */ @SneakyThrows protected Map search(String index, QueryBuilder queryBuilder, QueryBuilder rescorer, int resultSize) { + return search(index, queryBuilder, rescorer, resultSize, Map.of()); + } + + /** + * Execute a search request initialized from a neural query builder that can add a rescore query to the request + * + * @param index Index to search against + * @param queryBuilder queryBuilder to produce source of query + * @param rescorer used for rescorer query builder + * @param resultSize number of results to return in the search + * @param requestParams additional request params for search + * @return Search results represented as a map + */ + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams + ) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query"); queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -292,6 +326,9 @@ protected Map search(String index, QueryBuilder queryBuilder, Qu Request request = new Request("POST", "/" + index + "/_search"); request.addParameter("size", Integer.toString(resultSize)); + if (requestParams != null && !requestParams.isEmpty()) { + requestParams.forEach(request::addParameter); + } request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); @@ -384,7 +421,12 @@ protected int getHitCount(Map searchResponseAsMap) { */ @SneakyThrows protected void prepareKnnIndex(String indexName, List knnFieldConfigs) { - createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs), ""); + prepareKnnIndex(indexName, knnFieldConfigs, 3); + } + + @SneakyThrows + protected void prepareKnnIndex(String indexName, List knnFieldConfigs, int numOfShards) { + createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs, numOfShards), ""); } /** @@ -419,11 +461,11 @@ protected boolean checkComplete(Map node) { } @SneakyThrows - private String buildIndexConfiguration(List knnFieldConfigs) { + private String buildIndexConfiguration(List knnFieldConfigs, int numberOfShards) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject("settings") - .field("number_of_shards", 3) + .field("number_of_shards", numberOfShards) .field("index.knn", true) .endObject() .startObject("mappings") @@ -516,4 +558,152 @@ private String registerModelGroup() throws IOException, URISyntaxException { assertNotNull(modelGroupId); return modelGroupId; } + + public boolean isUpdateClusterSettings() { + return true; + } + + @SneakyThrows + protected void deleteModel(String modelId) { + // need to undeploy first as model can be in use + makeRequest( + client(), + "POST", + String.format(LOCALE, "/_plugins/_ml/models/%s/_undeploy", modelId), + null, + toHttpEntity(""), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); + makeRequest( + client(), + "DELETE", + String.format(LOCALE, "/_plugins/_ml/models/%s", modelId), + null, + toHttpEntity(""), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + protected void createSearchPipelineWithResultsPostProcessor(final String pipelineId) { + createSearchPipeline(pipelineId, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of()); + } + + @SneakyThrows + protected void createSearchPipeline( + final String pipelineId, + final String normalizationMethod, + String combinationMethod, + final Map combinationParams + ) { + StringBuilder stringBuilderForContentBody = new StringBuilder(); + stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",") + .append("\"phase_results_processors\": [{ ") + .append("\"normalization-processor\": {") + .append("\"normalization\": {") + .append("\"technique\": \"%s\"") + .append("},") + .append("\"combination\": {") + .append("\"technique\": \"%s\""); + if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) { + stringBuilderForContentBody.append(", \"parameters\": {"); + if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) { + stringBuilderForContentBody.append("\"weights\": ").append(combinationParams.get(PARAM_NAME_WEIGHTS)); + } + stringBuilderForContentBody.append(" }"); + } + stringBuilderForContentBody.append("}").append("}}]}"); + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity(String.format(LOCALE, stringBuilderForContentBody.toString(), normalizationMethod, combinationMethod)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + protected void createSearchPipelineWithDefaultResultsPostProcessor(final String pipelineId) { + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity( + String.format( + LOCALE, + "{\"description\": \"Post processor pipeline\"," + + "\"phase_results_processors\": [{ " + + "\"normalization-processor\": {}}]}" + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + protected void deleteSearchPipeline(final String pipelineId) { + makeRequest( + client(), + "DELETE", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity(""), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + /** + * Find all modesl that are currently deployed in the cluster + * @return set of model ids + */ + @SneakyThrows + protected Set findDeployedModels() { + + StringBuilder stringBuilderForContentBody = new StringBuilder(); + stringBuilderForContentBody.append("{") + .append("\"query\": { \"match_all\": {} },") + .append(" \"_source\": {") + .append(" \"includes\": [\"model_id\"],") + .append(" \"excludes\": [\"content\", \"model_content\"]") + .append("}}"); + + Response response = makeRequest( + client(), + "POST", + "/_plugins/_ml/models/_search", + null, + toHttpEntity(stringBuilderForContentBody.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + + String responseBody = EntityUtils.toString(response.getEntity()); + + Map models = XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); + Set modelIds = new HashSet<>(); + if (Objects.isNull(models) || models.isEmpty()) { + return modelIds; + } + + Map hits = (Map) models.get("hits"); + List> innerHitsMap = (List>) hits.get("hits"); + return innerHitsMap.stream() + .map(hit -> (Map) hit.get("_source")) + .filter(hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id")) + .map(hitsMap -> (String) hitsMap.get("model_id")) + .collect(Collectors.toSet()); + } + + /** + * Get the id for model currently deployed in the cluster. If there are no models deployed or it's more than 1 model + * fail on assertion + * @return id of deployed model + */ + protected String getDeployedModelId() { + Set modelIds = findDeployedModels(); + assertEquals(1, modelIds.size()); + return modelIds.iterator().next(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java new file mode 100644 index 000000000..7918126c5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.plugin; + +import static org.mockito.Mockito.mock; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.plugins.SearchPipelinePlugin; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.query.QueryPhaseSearcher; + +public class NeuralSearchTests extends OpenSearchQueryTestCase { + + public void testQuerySpecs() { + NeuralSearch plugin = new NeuralSearch(); + List> querySpecs = plugin.getQueries(); + + assertNotNull(querySpecs); + assertFalse(querySpecs.isEmpty()); + assertTrue(querySpecs.stream().anyMatch(spec -> NeuralQueryBuilder.NAME.equals(spec.getName().getPreferredName()))); + assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName()))); + } + + public void testQueryPhaseSearcher() { + NeuralSearch plugin = new NeuralSearch(); + Optional queryPhaseSearcher = plugin.getQueryPhaseSearcher(); + + assertNotNull(queryPhaseSearcher); + assertTrue(queryPhaseSearcher.isEmpty()); + + initFeatureFlags(); + + Optional queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher(); + + assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled); + assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty()); + assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher); + } + + public void testProcessors() { + NeuralSearch plugin = new NeuralSearch(); + Processor.Parameters processorParams = mock(Processor.Parameters.class); + Map processors = plugin.getProcessors(processorParams); + assertNotNull(processors); + assertNotNull(processors.get(TextEmbeddingProcessor.TYPE)); + } + + public void testSearchPhaseResultsProcessors() { + NeuralSearch plugin = new NeuralSearch(); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map> searchPhaseResultsProcessors = plugin + .getSearchPhaseResultsProcessors(parameters); + assertNotNull(searchPhaseResultsProcessors); + assertEquals(1, searchPhaseResultsProcessors.size()); + assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor")); + org.opensearch.search.pipeline.Processor.Factory scoringProcessor = searchPhaseResultsProcessors.get( + NormalizationProcessor.TYPE + ); + assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java new file mode 100644 index 000000000..a9b1fc9bf --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -0,0 +1,389 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; + +import lombok.SneakyThrows; + +import org.apache.commons.lang3.Range; +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class NormalizationProcessorIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT6 = "notexistingword"; + private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + private final static String RELATION_EQUAL_TO = "eq"; + + @Before + public void setUp() throws Exception { + super.setUp(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + findDeployedModels().forEach(this::deleteModel); + } + + /** + * Using search pipelines with result processor configs like below: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "min-max" + * }, + * "combination": { + * "technique": "arithmetic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + String modelId = getDeployedModelId(); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 5, false); + } + + /** + * Using search pipelines with default result processor configs: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * } + * } + * ] + * } + */ + @SneakyThrows + public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE); + String modelId = getDeployedModelId(); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 5, false); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + String modelId = getDeployedModelId(); + int totalExpectedDocQty = 6; + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 6, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 6, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(.75f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + + // verify the scores are normalized. we need special assert logic because combined score may vary as neural search query + // based on random vectors and return results for every doc. In some cases that may affect 1.0 score from term query and make it + // lower. + float highestScore = scores.stream().max(Double::compare).get().floatValue(); + assertTrue(Range.between(.75f, 1.0f).contains(highestScore)); + float lowestScore = scores.stream().min(Double::compare).get().floatValue(); + assertTrue(Range.between(.0f, .5f).contains(lowestScore)); + + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndNoMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT6)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 0, true); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 4, true); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 3 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "6", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5) + ); + assertEquals(6, getDocCount(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)); + } + } + + private List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + private Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map) hitsMap.get("total"); + } + + private Optional getMaxScore(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); + } + + private void assertQueryResults(Map searchResponseAsMap, int totalExpectedDocQty, boolean assertMinScore) { + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + if (totalExpectedDocQty > 0) { + assertEquals(1.0, getMaxScore(searchResponseAsMap).get(), 0.001f); + } else { + assertEquals(0.0, getMaxScore(searchResponseAsMap).get(), 0.001f); + } + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized + if (totalExpectedDocQty > 0) { + assertEquals(1.0, (double) scores.stream().max(Double::compare).get(), 0.001); + if (assertMinScore) { + assertEquals(0.001, (double) scores.stream().min(Double::compare).get(), 0.001); + } + } + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java new file mode 100644 index 000000000..397642007 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -0,0 +1,252 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseController; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchProgressListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.NoopCircuitBreaker; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.pipeline.PipelineAggregator; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class NormalizationProcessorTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String INDEX_NAME = "index1"; + private static final String NORMALIZATION_METHOD = "min_max"; + private static final String COMBINATION_METHOD = "arithmetic_mean"; + private SearchPhaseController searchPhaseController; + private ThreadPool threadPool; + private OpenSearchThreadPoolExecutor executor; + + @Before + public void setup() { + searchPhaseController = new SearchPhaseController(writableRegistry(), s -> new InternalAggregation.ReduceContextBuilder() { + @Override + public InternalAggregation.ReduceContext forPartialReduction() { + return InternalAggregation.ReduceContext.forPartialReduction( + BigArrays.NON_RECYCLING_INSTANCE, + null, + () -> PipelineAggregator.PipelineTree.EMPTY + ); + } + + public InternalAggregation.ReduceContext forFinalReduction() { + return InternalAggregation.ReduceContext.forFinalReduction( + BigArrays.NON_RECYCLING_INSTANCE, + null, + b -> {}, + PipelineAggregator.PipelineTree.EMPTY + ); + }; + }); + threadPool = new TestThreadPool(NormalizationProcessorTests.class.getName()); + executor = OpenSearchExecutors.newFixed( + "test", + 1, + 10, + OpenSearchExecutors.daemonThreadFactory("test"), + threadPool.getThreadContext() + ); + } + + @After + public void cleanup() { + executor.shutdownNow(); + terminate(threadPool); + } + + public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + assertEquals(DESCRIPTION, normalizationProcessor.getDescription()); + assertEquals(PROCESSOR_TAG, normalizationProcessor.getTag()); + assertEquals(SearchPhaseName.FETCH, normalizationProcessor.getAfterPhase()); + assertEquals(SearchPhaseName.QUERY, normalizationProcessor.getBeforePhase()); + assertFalse(normalizationProcessor.isIgnoreFailure()); + } + + public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombination() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + normalizationProcessorWorkflow + ); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + }) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(5); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + CompoundTopDocs topDocs = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } + ) + ) + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + + List querySearchResults = queryPhaseResultConsumer.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? null : result.queryResult()) + .collect(Collectors.toList()); + + TestUtils.assertQueryResultScores(querySearchResults); + } + + public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkflow() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME), + normalizationProcessorWorkflow + ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(null, searchPhaseContext); + + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any()); + } + + public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + normalizationProcessorWorkflow + ); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + }) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(5); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java new file mode 100644 index 000000000..453725a0d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.Mockito.spy; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +public class NormalizationProcessorWorkflowTests extends OpenSearchTestCase { + + public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombination() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + CompoundTopDocs topDocs = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } + ) + ) + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + querySearchResults.add(querySearchResult); + } + + normalizationProcessorWorkflow.execute( + querySearchResults, + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ); + + TestUtils.assertQueryResultScores(querySearchResults); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java new file mode 100644 index 000000000..e56532b52 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -0,0 +1,412 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores; +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreCombinationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + + @Before + public void setUp() throws Exception { + super.setUp(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + findDeployedModels().forEach(this::deleteModel); + } + + /** + * Using search pipelines with result processor configs like below: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "min-max" + * }, + * "combination": { + * "technique": "arithmetic_mean", + * "parameters": { + * "weights": [ + * 0.4, 0.7 + * ] + * } + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + // check case when number of weights and sub-queries are same + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseWithWeights1AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001); + + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) + ); + + Map searchResponseWithWeights2AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001); + + // check case when number of weights is less than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) + ); + + Map searchResponseWithWeights3AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); + + // check case when number of weights is more than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) + ); + + Map searchResponseWithWeights4AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); + } + + /** + * Using search pipelines with config for harmonic mean: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "harmonic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapDefaultNorm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderDefaultNorm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertHybridSearchResults(searchResponseAsMapDefaultNorm, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapL2Norm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderL2Norm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapL2Norm, 5, new float[] { 0.5f, 1.0f }); + } + + /** + * Using search pipelines with config for geometric mean: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "geometric_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapDefaultNorm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderDefaultNorm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertHybridSearchResults(searchResponseAsMapDefaultNorm, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapL2Norm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderL2Norm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapL2Norm, 5, new float[] { 0.5f, 1.0f }); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 3 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "6", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5) + ); + assertEquals(6, getDocCount(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..1a7f895cd --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.test.OpenSearchTestCase; + +public class ScoreCombinationTechniqueTests extends OpenSearchTestCase { + + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { + ScoreCombiner scoreCombiner = new ScoreCombiner(); + scoreCombiner.combineScores(List.of(), ScoreCombinationFactory.DEFAULT_METHOD); + } + + public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenScoresCombined() { + ScoreCombiner scoreCombiner = new ScoreCombiner(); + + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 1.0f), new ScoreDoc(2, .25f), new ScoreDoc(4, 0.001f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(5, 0.001f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) + ) + ) + ); + + scoreCombiner.combineScores(queryTopDocs, ScoreCombinationFactory.DEFAULT_METHOD); + + assertNotNull(queryTopDocs); + assertEquals(3, queryTopDocs.size()); + + assertEquals(3, queryTopDocs.get(0).scoreDocs.length); + assertEquals(1.0, queryTopDocs.get(0).scoreDocs[0].score, 0.001f); + assertEquals(1, queryTopDocs.get(0).scoreDocs[0].doc); + assertEquals(1.0, queryTopDocs.get(0).scoreDocs[1].score, 0.001f); + assertEquals(3, queryTopDocs.get(0).scoreDocs[1].doc); + assertEquals(0.25, queryTopDocs.get(0).scoreDocs[2].score, 0.001f); + assertEquals(2, queryTopDocs.get(0).scoreDocs[2].doc); + + assertEquals(4, queryTopDocs.get(1).scoreDocs.length); + assertEquals(0.9, queryTopDocs.get(1).scoreDocs[0].score, 0.001f); + assertEquals(2, queryTopDocs.get(1).scoreDocs[0].doc); + assertEquals(0.6, queryTopDocs.get(1).scoreDocs[1].score, 0.001f); + assertEquals(4, queryTopDocs.get(1).scoreDocs[1].doc); + assertEquals(0.5, queryTopDocs.get(1).scoreDocs[2].score, 0.001f); + assertEquals(7, queryTopDocs.get(1).scoreDocs[2].doc); + assertEquals(0.01, queryTopDocs.get(1).scoreDocs[3].score, 0.001f); + assertEquals(9, queryTopDocs.get(1).scoreDocs[3].doc); + + assertEquals(0, queryTopDocs.get(2).scoreDocs.length); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java new file mode 100644 index 000000000..64b3fe07f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -0,0 +1,295 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreNormalizationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + findDeployedModels().forEach(this::deleteModel); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + /** + * Using search pipelines with config for l2 norm: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "arithmetic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); + hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapArithmeticMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderArithmeticMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 0.6f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); + hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapHarmonicMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderHarmonicMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapHarmonicMean, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); + hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapGeometricMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderGeometricMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapGeometricMean, 5, new float[] { 0.5f, 1.0f }); + } + + /** + * Using search pipelines with config for min-max norm: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "arithmetic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); + hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapArithmeticMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderArithmeticMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 1.0f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); + hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapHarmonicMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderHarmonicMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapHarmonicMean, 5, new float[] { 0.6f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); + hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapGeometricMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderGeometricMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapGeometricMean, 5, new float[] { 0.6f, 1.0f }); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java new file mode 100644 index 000000000..6188a7ef5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; + +import lombok.SneakyThrows; + +import org.apache.commons.lang3.Range; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.test.OpenSearchTestCase; + +public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { + + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); + } + + @SneakyThrows + public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(1, resultDoc.getCompoundTopDocs().size()); + TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertEquals(1, topDoc.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); + assertNotNull(topDoc.scoreDocs); + assertEquals(1, topDoc.scoreDocs.length); + ScoreDoc scoreDoc = topDoc.scoreDocs[0]; + assertEquals(1.0, scoreDoc.score, 0.001f); + assertEquals(1, scoreDoc.doc); + } + + @SneakyThrows + public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ) + ) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(1, resultDoc.getCompoundTopDocs().size()); + TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertEquals(3, topDoc.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); + assertNotNull(topDoc.scoreDocs); + assertEquals(3, topDoc.scoreDocs.length); + ScoreDoc highScoreDoc = topDoc.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDoc.scoreDocs[topDoc.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + } + + public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } + ) + ) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(2, resultDoc.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocSubqueryOne = resultDoc.getCompoundTopDocs().get(0); + assertEquals(3, topDocSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); + assertNotNull(topDocSubqueryOne.scoreDocs); + assertEquals(3, topDocSubqueryOne.scoreDocs.length); + ScoreDoc highScoreDoc = topDocSubqueryOne.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDocSubqueryOne.scoreDocs[topDocSubqueryOne.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + // sub-query two + TopDocs topDocSubqueryTwo = resultDoc.getCompoundTopDocs().get(1); + assertEquals(2, topDocSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); + assertNotNull(topDocSubqueryTwo.scoreDocs); + assertEquals(2, topDocSubqueryTwo.scoreDocs.length); + assertEquals(1.0, topDocSubqueryTwo.scoreDocs[0].score, 0.001f); + assertEquals(3, topDocSubqueryTwo.scoreDocs[0].doc); + assertEquals(0.0, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].score, 0.001f); + assertEquals(5, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].doc); + } + + public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 2.2f), new ScoreDoc(4, 1.8f), new ScoreDoc(7, 0.9f), new ScoreDoc(9, 0.01f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) + ) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + assertNotNull(queryTopDocs); + assertEquals(3, queryTopDocs.size()); + // shard one + CompoundTopDocs resultDocShardOne = queryTopDocs.get(0); + assertEquals(2, resultDocShardOne.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocSubqueryOne = resultDocShardOne.getCompoundTopDocs().get(0); + assertEquals(3, topDocSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); + assertNotNull(topDocSubqueryOne.scoreDocs); + assertEquals(3, topDocSubqueryOne.scoreDocs.length); + ScoreDoc highScoreDoc = topDocSubqueryOne.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDocSubqueryOne.scoreDocs[topDocSubqueryOne.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + // sub-query two + TopDocs topDocSubqueryTwo = resultDocShardOne.getCompoundTopDocs().get(1); + assertEquals(2, topDocSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); + assertNotNull(topDocSubqueryTwo.scoreDocs); + assertEquals(2, topDocSubqueryTwo.scoreDocs.length); + assertTrue(Range.between(0.0f, 1.0f).contains(topDocSubqueryTwo.scoreDocs[0].score)); + assertEquals(3, topDocSubqueryTwo.scoreDocs[0].doc); + assertTrue(Range.between(0.0f, 1.0f).contains(topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].score)); + assertEquals(5, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].doc); + + // shard two + CompoundTopDocs resultDocShardTwo = queryTopDocs.get(1); + assertEquals(2, resultDocShardTwo.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getCompoundTopDocs().get(0); + assertEquals(0, topDocShardTwoSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryOne.totalHits.relation); + assertNotNull(topDocShardTwoSubqueryOne.scoreDocs); + assertEquals(0, topDocShardTwoSubqueryOne.scoreDocs.length); + // sub-query two + TopDocs topDocShardTwoSubqueryTwo = resultDocShardTwo.getCompoundTopDocs().get(1); + assertEquals(4, topDocShardTwoSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryTwo.totalHits.relation); + assertNotNull(topDocShardTwoSubqueryTwo.scoreDocs); + assertEquals(4, topDocShardTwoSubqueryTwo.scoreDocs.length); + assertEquals(1.0, topDocShardTwoSubqueryTwo.scoreDocs[0].score, 0.001f); + assertEquals(2, topDocShardTwoSubqueryTwo.scoreDocs[0].doc); + assertEquals(0.0, topDocShardTwoSubqueryTwo.scoreDocs[topDocShardTwoSubqueryTwo.scoreDocs.length - 1].score, 0.001f); + assertEquals(9, topDocShardTwoSubqueryTwo.scoreDocs[topDocShardTwoSubqueryTwo.scoreDocs.length - 1].doc); + + // shard three + CompoundTopDocs resultDocShardThree = queryTopDocs.get(2); + assertEquals(2, resultDocShardThree.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getCompoundTopDocs().get(0); + assertEquals(0, topDocShardThreeSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryOne.totalHits.relation); + assertEquals(0, topDocShardThreeSubqueryOne.scoreDocs.length); + // sub-query two + TopDocs topDocShardThreeSubqueryTwo = resultDocShardThree.getCompoundTopDocs().get(1); + assertEquals(0, topDocShardThreeSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryTwo.totalHits.relation); + assertEquals(0, topDocShardThreeSubqueryTwo.scoreDocs.length); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 10d4779c8..b306a7c90 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -9,9 +9,12 @@ import java.nio.file.Path; import java.util.Map; +import lombok.SneakyThrows; + import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; import org.apache.http.util.EntityUtils; +import org.junit.After; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; @@ -25,6 +28,17 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private static final String PIPELINE_NAME = "pipeline-hybrid"; + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + findDeployedModels().forEach(this::deleteModel); + } + public void testTextEmbeddingProcessor() throws Exception { String modelId = uploadTextEmbeddingModel(); loadModel(modelId); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..842df736d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public ArithmeticMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = this::arithmeticMean; + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, -1.0f, 0.6f); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + float expectedScore = 0.825f; + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + private float arithmeticMean(List scores, List weights) { + assertEquals(scores.size(), weights.size()); + float sumOfWeightedScores = 0; + float sumOfWeights = 0; + for (int i = 0; i < scores.size(); i++) { + float score = scores.get(i); + float weight = weights.get(i).floatValue(); + if (score >= 0) { + sumOfWeightedScores += score * weight; + sumOfWeights += weight; + } + } + return sumOfWeights == 0 ? 0f : sumOfWeightedScores / sumOfWeights; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..e8ad3532d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.Arrays; +import java.util.List; +import java.util.function.BiFunction; + +import lombok.NoArgsConstructor; + +import org.apache.commons.lang.ArrayUtils; +import org.opensearch.test.OpenSearchTestCase; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +@NoArgsConstructor +public class BaseScoreCombinationTechniqueTests extends OpenSearchTestCase { + + protected BiFunction, List, Float> expectedScoreFunction; + protected static final int RANDOM_SCORES_SIZE = 100; + + private static final float DELTA_FOR_ASSERTION = 0.0001f; + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(final ScoreCombinationTechnique technique) { + float[] scores = { 1.0f, 0.5f, 0.3f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), List.of(1.0, 1.0, 1.0)); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(final ScoreCombinationTechnique technique) { + float[] scores = { 1.0f, -1.0f, 0.6f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), List.of(1.0, 1.0, 1.0)); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + List scores, + float expectedScore + ) { + float[] scoresArray = new float[scores.size()]; + for (int i = 0; i < scoresArray.length; i++) { + scoresArray[i] = scores.get(i); + } + float actualScore = technique.combine(scoresArray); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + final List weights + ) { + float[] scores = new float[weights.size()]; + for (int i = 0; i < RANDOM_SCORES_SIZE; i++) { + scores[i] = randomScore(); + } + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + List scores, + float expectedScore + ) { + float[] scoresArray = new float[scores.size()]; + for (int i = 0; i < scoresArray.length; i++) { + scoresArray[i] = scores.get(i); + } + float actualScore = technique.combine(scoresArray); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + final List weights + ) { + float[] scores = new float[weights.size()]; + for (int i = 0; i < RANDOM_SCORES_SIZE; i++) { + scores[i] = randomScore(); + } + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + private float randomScore() { + return RandomizedTest.randomBoolean() ? -1.0f : RandomizedTest.randomFloat(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..fe0d962ca --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public GeometricMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = this::geometricMean; + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, 0.5f, 0.3f); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + float expectedScore = 0.5797f; + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, -1.0f, 0.6f); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + float expectedScore = 0.7997f; + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + /** + * Verify score correctness by using alternative formula for geometric mean as n-th root of product of weighted scores, + * more details in here https://en.wikipedia.org/wiki/Weighted_geometric_mean + */ + private float geometricMean(List scores, List weights) { + float product = 1.0f; + float sumOfWeights = 0.0f; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.size(); indexOfSubQuery++) { + float score = scores.get(indexOfSubQuery); + if (score <= 0) { + // scores 0.0 need to be skipped, ln() of 0 is not defined + continue; + } + float weight = weights.get(indexOfSubQuery).floatValue(); + product *= Math.pow(score, weight); + sumOfWeights += weight; + } + return sumOfWeights == 0 ? 0f : (float) Math.pow(product, (float) 1 / sumOfWeights); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..8187123a1 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public HarmonicMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = (scores, weights) -> harmonicMean(scores, weights); + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, 0.5f, 0.3f); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + float expecteScore = 0.4954f; + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expecteScore); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, -1.0f, 0.6f); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + float expectedScore = 0.7741f; + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + private float harmonicMean(List scores, List weights) { + assertEquals(scores.size(), weights.size()); + float w = 0, h = 0; + for (int i = 0; i < scores.size(); i++) { + float score = scores.get(i), weight = weights.get(i).floatValue(); + if (score > 0) { + w += weight; + h += weight / score; + } + } + return h == 0 ? 0f : w / h; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java new file mode 100644 index 000000000..ee46bf0a1 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.hamcrest.Matchers.containsString; + +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class ScoreCombinationFactoryTests extends OpenSearchQueryTestCase { + + public void testArithmeticWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("arithmetic_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof ArithmeticMeanScoreCombinationTechnique); + } + + public void testHarmonicWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("harmonic_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof HarmonicMeanScoreCombinationTechnique); + } + + public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("geometric_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof GeometricMeanScoreCombinationTechnique); + } + + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + IllegalArgumentException illegalArgumentException = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationFactory.createCombination("randomname") + ); + org.hamcrest.MatcherAssert.assertThat( + illegalArgumentException.getMessage(), + containsString("provided combination technique is not supported") + ); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java new file mode 100644 index 000000000..83bb0e7bb --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -0,0 +1,388 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.COMBINATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.NORMALIZATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.PARAMETERS; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.TECHNIQUE; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class NormalizationProcessorFactoryTests extends OpenSearchTestCase { + + private static final String NORMALIZATION_METHOD = "min_max"; + private static final String COMBINATION_METHOD = "arithmetic_mean"; + + @SneakyThrows + public void testNormalizationProcessor_whenNoParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenTechniqueNamesNotSet_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put("normalization", new HashMap<>(Map.of())); + config.put("combination", new HashMap<>(Map.of())); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenWithParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put("normalization", new HashMap<>(Map.of("technique", "min_max"))); + config.put("combination", new HashMap<>(Map.of("technique", "arithmetic_mean"))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); + config.put( + COMBINATION_CLAUSE, + new HashMap<>( + Map.of( + TECHNIQUE, + "arithmetic_mean", + PARAMETERS, + new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble()))) + ) + ) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + public void testInputValidation_whenInvalidNormalizationClause_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + + expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, "")), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)) + ) + ), + pipelineContext + ) + ); + + expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, "random_name_for_normalization")), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)) + ) + ), + pipelineContext + ) + ); + } + + public void testInputValidation_whenInvalidCombinationClause_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + + expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, "")) + ) + ), + pipelineContext + ) + ); + + expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, "random_name_for_combination")) + ) + ), + pipelineContext + ) + ); + } + + public void testInputValidation_whenInvalidCombinationParams_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + + IllegalArgumentException exceptionBadTechnique = expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap( + Map.of( + TECHNIQUE, + "", + NormalizationProcessorFactory.PARAMETERS, + new HashMap<>(Map.of("weights", "random_string")) + ) + ) + ) + ), + pipelineContext + ) + ); + assertThat(exceptionBadTechnique.getMessage(), containsString("provided combination technique is not supported")); + + IllegalArgumentException exceptionInvalidWeights = expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap( + Map.of( + TECHNIQUE, + COMBINATION_METHOD, + NormalizationProcessorFactory.PARAMETERS, + new HashMap<>(Map.of("weights", 5.0)) + ) + ) + ) + ), + pipelineContext + ) + ); + assertThat(exceptionInvalidWeights.getMessage(), containsString("parameter [weights] must be a collection of numbers")); + + IllegalArgumentException exceptionInvalidWeights2 = expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap( + Map.of( + TECHNIQUE, + COMBINATION_METHOD, + NormalizationProcessorFactory.PARAMETERS, + new HashMap<>(Map.of("weights", new Boolean[] { true, false })) + ) + ) + ) + ), + pipelineContext + ) + ); + assertThat(exceptionInvalidWeights2.getMessage(), containsString("parameter [weights] must be a collection of numbers")); + + IllegalArgumentException exceptionInvalidParam = expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap( + Map.of( + TECHNIQUE, + COMBINATION_METHOD, + NormalizationProcessorFactory.PARAMETERS, + new HashMap<>(Map.of("random_param", "value")) + ) + ) + ) + ), + pipelineContext + ) + ); + assertThat( + exceptionInvalidParam.getMessage(), + containsString("provided parameter for combination technique is not supported. supported parameters are [weights]") + ); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java new file mode 100644 index 000000000..c5f8c4860 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -0,0 +1,225 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.Arrays; +import java.util.List; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts normalization of scores based on min-max method + */ +public class L2ScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { + private static final float DELTA_FOR_ASSERTION = 0.0001f; + + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { + L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique(); + Float[] scores = { 0.5f, 0.2f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } + ) + ) + ) + ); + normalizationTechnique.normalize(compoundTopDocs); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(2, l2Norm(scores[0], Arrays.asList(scores))), + new ScoreDoc(4, l2Norm(scores[1], Arrays.asList(scores))) } + ) + ) + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); + assertCompoundTopDocs(expectedCompoundDocs, compoundTopDocs.get(0).getCompoundTopDocs().get(0)); + } + + public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { + L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique(); + Float[] scoresQuery1 = { 0.5f, 0.2f }; + Float[] scoresQuery2 = { 0.9f, 0.7f, 0.1f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresQuery1[0]), new ScoreDoc(4, scoresQuery1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresQuery2[0]), + new ScoreDoc(4, scoresQuery2[1]), + new ScoreDoc(2, scoresQuery2[2]) } + ) + ) + ) + ); + normalizationTechnique.normalize(compoundTopDocs); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(2, l2Norm(scoresQuery1[0], Arrays.asList(scoresQuery1))), + new ScoreDoc(4, l2Norm(scoresQuery1[1], Arrays.asList(scoresQuery1))) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, l2Norm(scoresQuery2[0], Arrays.asList(scoresQuery2))), + new ScoreDoc(4, l2Norm(scoresQuery2[1], Arrays.asList(scoresQuery2))), + new ScoreDoc(2, l2Norm(scoresQuery2[2], Arrays.asList(scoresQuery2))) } + ) + ) + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getCompoundTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getCompoundTopDocs().get(i), compoundTopDocs.get(0).getCompoundTopDocs().get(i)); + } + } + + public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() { + L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique(); + Float[] scoresShard1Query1 = { 0.5f, 0.2f }; + Float[] scoresShard1and2Query3 = { 0.9f, 0.7f, 0.1f, 0.8f, 0.7f, 0.6f, 0.5f }; + Float[] scoresShard2Query2 = { 2.9f, 0.7f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresShard1Query1[0]), new ScoreDoc(4, scoresShard1Query1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[0]), + new ScoreDoc(4, scoresShard1and2Query3[1]), + new ScoreDoc(2, scoresShard1and2Query3[2]) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, scoresShard2Query2[0]), new ScoreDoc(9, scoresShard2Query2[1]) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[3]), + new ScoreDoc(9, scoresShard1and2Query3[4]), + new ScoreDoc(10, scoresShard1and2Query3[5]), + new ScoreDoc(15, scoresShard1and2Query3[6]) } + ) + ) + ) + ); + normalizationTechnique.normalize(compoundTopDocs); + + CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(2, l2Norm(scoresShard1Query1[0], Arrays.asList(scoresShard1Query1))), + new ScoreDoc(4, l2Norm(scoresShard1Query1[1], Arrays.asList(scoresShard1Query1))) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, l2Norm(scoresShard1and2Query3[0], Arrays.asList(scoresShard1and2Query3))), + new ScoreDoc(4, l2Norm(scoresShard1and2Query3[1], Arrays.asList(scoresShard1and2Query3))), + new ScoreDoc(2, l2Norm(scoresShard1and2Query3[2], Arrays.asList(scoresShard1and2Query3))) } + ) + ) + ); + + CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(7, l2Norm(scoresShard2Query2[0], Arrays.asList(scoresShard2Query2))), + new ScoreDoc(9, l2Norm(scoresShard2Query2[1], Arrays.asList(scoresShard2Query2))) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, l2Norm(scoresShard1and2Query3[3], Arrays.asList(scoresShard1and2Query3))), + new ScoreDoc(9, l2Norm(scoresShard1and2Query3[4], Arrays.asList(scoresShard1and2Query3))), + new ScoreDoc(10, l2Norm(scoresShard1and2Query3[5], Arrays.asList(scoresShard1and2Query3))), + new ScoreDoc(15, l2Norm(scoresShard1and2Query3[6], Arrays.asList(scoresShard1and2Query3))) } + ) + ) + ); + + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getCompoundTopDocs().size(); i++) { + assertCompoundTopDocs( + expectedCompoundDocsShard1.getCompoundTopDocs().get(i), + compoundTopDocs.get(0).getCompoundTopDocs().get(i) + ); + } + assertNotNull(compoundTopDocs.get(1).getCompoundTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getCompoundTopDocs().size(); i++) { + assertCompoundTopDocs( + expectedCompoundDocsShard2.getCompoundTopDocs().get(i), + compoundTopDocs.get(1).getCompoundTopDocs().get(i) + ); + } + } + + private float l2Norm(float score, List scores) { + return score / (float) Math.sqrt(scores.stream().map(Float::doubleValue).map(s -> s * s).mapToDouble(Double::doubleValue).sum()); + } + + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { + assertEquals(expected.totalHits.value, actual.totalHits.value); + assertEquals(expected.totalHits.relation, actual.totalHits.relation); + assertEquals(expected.scoreDocs.length, actual.scoreDocs.length); + for (int i = 0; i < expected.scoreDocs.length; i++) { + assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION); + assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc); + assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java new file mode 100644 index 000000000..c7de1fdb5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -0,0 +1,178 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.List; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts normalization of scores based on min-max method + */ +public class MinMaxScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { + private static final float DELTA_FOR_ASSERTION = 0.0001f; + + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { + MinMaxScoreNormalizationTechnique normalizationTechnique = new MinMaxScoreNormalizationTechnique(); + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ) + ) + ) + ); + normalizationTechnique.normalize(compoundTopDocs); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.001f) } + ) + ) + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); + assertCompoundTopDocs(expectedCompoundDocs, compoundTopDocs.get(0).getCompoundTopDocs().get(0)); + } + + public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { + MinMaxScoreNormalizationTechnique normalizationTechnique = new MinMaxScoreNormalizationTechnique(); + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } + ) + ) + ) + ); + normalizationTechnique.normalize(compoundTopDocs); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.001f) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } + ) + ) + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getCompoundTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getCompoundTopDocs().get(i), compoundTopDocs.get(0).getCompoundTopDocs().get(i)); + } + } + + public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() { + MinMaxScoreNormalizationTechnique normalizationTechnique = new MinMaxScoreNormalizationTechnique(); + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) } + ) + ) + ) + ); + normalizationTechnique.normalize(compoundTopDocs); + + CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.001f) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } + ) + ) + ); + + CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, 0.001f) } + ) + ) + ); + + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getCompoundTopDocs().size(); i++) { + assertCompoundTopDocs( + expectedCompoundDocsShard1.getCompoundTopDocs().get(i), + compoundTopDocs.get(0).getCompoundTopDocs().get(i) + ); + } + assertNotNull(compoundTopDocs.get(1).getCompoundTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getCompoundTopDocs().size(); i++) { + assertCompoundTopDocs( + expectedCompoundDocsShard2.getCompoundTopDocs().get(i), + compoundTopDocs.get(1).getCompoundTopDocs().get(i) + ); + } + } + + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { + assertEquals(expected.totalHits.value, actual.totalHits.value); + assertEquals(expected.totalHits.relation, actual.totalHits.relation); + assertEquals(expected.scoreDocs.length, actual.scoreDocs.length); + for (int i = 0; i < expected.scoreDocs.length; i++) { + assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION); + assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc); + assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java new file mode 100644 index 000000000..bf1489fe3 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import static org.hamcrest.Matchers.containsString; + +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class ScoreNormalizationFactoryTests extends OpenSearchQueryTestCase { + + public void testMinMaxNorm_whenCreatingByName_thenReturnCorrectInstance() { + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("min_max"); + + assertNotNull(scoreNormalizationTechnique); + assertTrue(scoreNormalizationTechnique instanceof MinMaxScoreNormalizationTechnique); + } + + public void testL2Norm_whenCreatingByName_thenReturnCorrectInstance() { + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("l2"); + + assertNotNull(scoreNormalizationTechnique); + assertTrue(scoreNormalizationTechnique instanceof L2ScoreNormalizationTechnique); + } + + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + IllegalArgumentException illegalArgumentException = expectThrows( + IllegalArgumentException.class, + () -> scoreNormalizationFactory.createNormalization("randomname") + ); + assertThat(illegalArgumentException.getMessage(), containsString("provided normalization technique is not supported")); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java new file mode 100644 index 000000000..49d1ba974 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -0,0 +1,719 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; +import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; +import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_TEXT_FIELD; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +import lombok.SneakyThrows; + +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.FilterStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.index.Index; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.KNNQuery; +import org.opensearch.knn.index.query.KNNQueryBuilder; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class HybridQueryBuilderTests extends OpenSearchQueryTestCase { + static final String VECTOR_FIELD_NAME = "vectorField"; + static final String TEXT_FIELD_NAME = "field"; + static final String QUERY_TEXT = "Hello world!"; + static final String TERM_QUERY_TEXT = "keyword"; + static final String MODEL_ID = "mfgfgdsfgfdgsde"; + static final int K = 10; + static final float BOOST = 1.8f; + static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[4]; + static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); + + @SneakyThrows + public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + Query queryNoSubQueries = queryBuilder.doToQuery(mockQueryShardContext); + assertTrue(queryNoSubQueries instanceof MatchNoDocsQuery); + } + + @SneakyThrows + public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + + queryBuilder.add(neuralQueryBuilder); + Query queryOnlyNeural = queryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(queryOnlyNeural); + assertTrue(queryOnlyNeural instanceof HybridQuery); + assertEquals(1, ((HybridQuery) queryOnlyNeural).getSubQueries().size()); + assertTrue(((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next() instanceof KNNQuery); + KNNQuery knnQuery = (KNNQuery) ((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next(); + assertEquals(VECTOR_FIELD_NAME, knnQuery.getField()); + assertEquals(K, knnQuery.getK()); + assertNotNull(knnQuery.getQueryVector()); + } + + @SneakyThrows + public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + + queryBuilder.add(neuralQueryBuilder); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + queryBuilder.add(termSubQuery); + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + Query queryTwoSubQueries = queryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(queryTwoSubQueries); + assertTrue(queryTwoSubQueries instanceof HybridQuery); + assertEquals(2, ((HybridQuery) queryTwoSubQueries).getSubQueries().size()); + // verify knn vector query + Iterator queryIterator = ((HybridQuery) queryTwoSubQueries).getSubQueries().iterator(); + Query firstQuery = queryIterator.next(); + assertTrue(firstQuery instanceof KNNQuery); + KNNQuery knnQuery = (KNNQuery) firstQuery; + assertEquals(VECTOR_FIELD_NAME, knnQuery.getField()); + assertEquals(K, knnQuery.getK()); + assertNotNull(knnQuery.getQueryVector()); + // verify term query + Query secondQuery = queryIterator.next(); + assertTrue(secondQuery instanceof TermQuery); + TermQuery termQuery = (TermQuery) secondQuery; + assertEquals(TEXT_FIELD_NAME, termQuery.getTerm().field()); + assertEquals(TERM_QUERY_TEXT, termQuery.getTerm().text()); + } + + @SneakyThrows + public void testDoToQuery_whenTooManySubqueries_thenFail() { + // create query with 6 sub-queries, which is more than current max allowed + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startArray("queries") + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .endArray() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilder.contentType().xContent(), + BytesReference.bytes(xContentBuilder) + ); + contentParser.nextToken(); + + ParsingException exception = expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser)); + assertThat(exception.getMessage(), containsString("Number of sub-queries exceeds maximum supported")); + } + + /** + * Tests basic query: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 1 + * } + * } + * }, + * { + * "term": { + * "text": "keyword" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startArray("queries") + .startObject() + .startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .endArray() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(NeuralQueryBuilder.NAME), + NeuralQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilder.contentType().xContent(), + BytesReference.bytes(xContentBuilder) + ); + contentParser.nextToken(); + + HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder.fromXContent(contentParser); + assertEquals(2, queryTwoSubQueries.queries().size()); + assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder); + assertTrue(queryTwoSubQueries.queries().get(1) instanceof TermQueryBuilder); + // verify knn vector query + NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0); + assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(K, neuralQueryBuilder.k()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(BOOST, neuralQueryBuilder.boost(), 0f); + // verify term query + TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryTwoSubQueries.queries().get(1); + assertEquals(TEXT_FIELD_NAME, termQueryBuilder.fieldName()); + assertEquals(TERM_QUERY_TEXT, termQueryBuilder.value()); + } + + @SneakyThrows + public void testFromXContent_whenIncorrectFormat_thenFail() { + XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startArray("random_field") + .startObject() + .startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .endArray() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(NeuralQueryBuilder.NAME), + NeuralQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + unsupportedFieldXContentBuilder.contentType().xContent(), + BytesReference.bytes(unsupportedFieldXContentBuilder) + ); + contentParser.nextToken(); + + expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser)); + + XContentBuilder emptySubQueriesXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startArray("queries") + .endArray() + .endObject(); + + XContentParser contentParser2 = createParser( + namedXContentRegistry, + unsupportedFieldXContentBuilder.contentType().xContent(), + BytesReference.bytes(emptySubQueriesXContentBuilder) + ); + contentParser2.nextToken(); + + expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser2)); + } + + @SneakyThrows + public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER); + + queryBuilder.add(neuralQueryBuilder); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + queryBuilder.add(termSubQuery); + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map out = xContentBuilderToMap(builder); + + Object outer = out.get(HybridQueryBuilder.NAME); + if (!(outer instanceof Map)) { + fail("hybrid does not map to nested object"); + } + + Map outerMap = (Map) outer; + + assertNotNull(outerMap); + assertTrue(outerMap.containsKey("queries")); + assertTrue(outerMap.get("queries") instanceof List); + List listWithQueries = (List) outerMap.get("queries"); + assertEquals(2, listWithQueries.size()); + + // verify neural search query + Map vectorFieldInnerMap = getInnerMap(listWithQueries.get(0), NeuralQueryBuilder.NAME, VECTOR_FIELD_NAME); + assertEquals(MODEL_ID, vectorFieldInnerMap.get(MODEL_ID_FIELD.getPreferredName())); + assertEquals(QUERY_TEXT, vectorFieldInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); + assertEquals(K, vectorFieldInnerMap.get(K_FIELD.getPreferredName())); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + assertEquals( + xContentBuilderToMap(TEST_FILTER.toXContent(xContentBuilder, EMPTY_PARAMS)), + vectorFieldInnerMap.get(FILTER_FIELD.getPreferredName()) + ); + // verify term query + Map termFieldInnerMap = getInnerMap(listWithQueries.get(1), TermQueryBuilder.NAME, TEXT_FIELD_NAME); + assertEquals(TERM_QUERY_TEXT, termFieldInnerMap.get("value")); + } + + @SneakyThrows + public void testStreams_whenWrittingToStream_thenSuccessful() { + HybridQueryBuilder original = new HybridQueryBuilder(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + + original.add(neuralQueryBuilder); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + original.add(termSubQuery); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + + FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of( + new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new), + new NamedWriteableRegistry.Entry(QueryBuilder.class, NeuralQueryBuilder.NAME, NeuralQueryBuilder::new), + new NamedWriteableRegistry.Entry(QueryBuilder.class, HybridQueryBuilder.NAME, HybridQueryBuilder::new) + ) + ) + ); + + HybridQueryBuilder copy = new HybridQueryBuilder(filterStreamInput); + assertEquals(original, copy); + } + + public void testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { + HybridQueryBuilder hybridQueryBuilderBaseline = new HybridQueryBuilder(); + hybridQueryBuilderBaseline.add( + new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + hybridQueryBuilderBaseline.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); + + HybridQueryBuilder hybridQueryBuilderBaselineCopy = new HybridQueryBuilder(); + hybridQueryBuilderBaselineCopy.add( + new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + hybridQueryBuilderBaselineCopy.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); + + assertEquals(hybridQueryBuilderBaseline, hybridQueryBuilderBaseline); + assertEquals(hybridQueryBuilderBaseline.hashCode(), hybridQueryBuilderBaseline.hashCode()); + + assertEquals(hybridQueryBuilderBaselineCopy, hybridQueryBuilderBaselineCopy); + assertEquals(hybridQueryBuilderBaselineCopy.hashCode(), hybridQueryBuilderBaselineCopy.hashCode()); + } + + public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { + String modelId = "testModelId"; + String fieldName = "fieldTwo"; + String queryText = "query text"; + String termText = "another keyword"; + + HybridQueryBuilder hybridQueryBuilderBaseline = new HybridQueryBuilder(); + hybridQueryBuilderBaseline.add( + new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + hybridQueryBuilderBaseline.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); + + HybridQueryBuilder hybridQueryBuilderOnlyOneSubQuery = new HybridQueryBuilder(); + hybridQueryBuilderOnlyOneSubQuery.add( + new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + + HybridQueryBuilder hybridQueryBuilderOnlyDifferentModelId = new HybridQueryBuilder(); + hybridQueryBuilderOnlyDifferentModelId.add( + new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(modelId) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + hybridQueryBuilderBaseline.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); + + HybridQueryBuilder hybridQueryBuilderOnlyDifferentFieldName = new HybridQueryBuilder(); + hybridQueryBuilderOnlyDifferentFieldName.add( + new NeuralQueryBuilder().fieldName(fieldName) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + hybridQueryBuilderOnlyDifferentFieldName.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); + + HybridQueryBuilder hybridQueryBuilderOnlyDifferentQuery = new HybridQueryBuilder(); + hybridQueryBuilderOnlyDifferentQuery.add( + new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(queryText) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + hybridQueryBuilderOnlyDifferentQuery.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); + + HybridQueryBuilder hybridQueryBuilderOnlyDifferentTermValue = new HybridQueryBuilder(); + hybridQueryBuilderOnlyDifferentTermValue.add( + new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER) + ); + hybridQueryBuilderOnlyDifferentTermValue.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, termText)); + + assertNotEquals(hybridQueryBuilderBaseline, hybridQueryBuilderOnlyOneSubQuery); + assertNotEquals(hybridQueryBuilderBaseline.hashCode(), hybridQueryBuilderOnlyOneSubQuery.hashCode()); + + assertNotEquals(hybridQueryBuilderBaseline, hybridQueryBuilderOnlyDifferentModelId); + assertNotEquals(hybridQueryBuilderBaseline.hashCode(), hybridQueryBuilderOnlyDifferentModelId.hashCode()); + + assertNotEquals(hybridQueryBuilderBaseline, hybridQueryBuilderOnlyDifferentFieldName); + assertNotEquals(hybridQueryBuilderBaseline.hashCode(), hybridQueryBuilderOnlyDifferentFieldName.hashCode()); + + assertNotEquals(hybridQueryBuilderBaseline, hybridQueryBuilderOnlyDifferentQuery); + assertNotEquals(hybridQueryBuilderBaseline.hashCode(), hybridQueryBuilderOnlyDifferentQuery.hashCode()); + + assertNotEquals(hybridQueryBuilderBaseline, hybridQueryBuilderOnlyDifferentTermValue); + assertNotEquals(hybridQueryBuilderBaseline.hashCode(), hybridQueryBuilderOnlyDifferentTermValue.hashCode()); + } + + @SneakyThrows + public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + + queryBuilder.add(neuralQueryBuilder); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + queryBuilder.add(termSubQuery); + + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + Index dummyIndex = new Index("dummy", "dummy"); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + QueryBuilder queryBuilderAfterRewrite = queryBuilder.doRewrite(mockQueryShardContext); + assertTrue(queryBuilderAfterRewrite instanceof HybridQueryBuilder); + HybridQueryBuilder hybridQueryBuilder = (HybridQueryBuilder) queryBuilderAfterRewrite; + assertNotNull(hybridQueryBuilder.queries()); + assertEquals(2, hybridQueryBuilder.queries().size()); + List queryBuilders = hybridQueryBuilder.queries(); + // verify each sub-query builder + assertTrue(queryBuilders.get(0) instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilders.get(0); + assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); + assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); + assertTrue(queryBuilders.get(1) instanceof TermQueryBuilder); + TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryBuilders.get(1); + assertEquals(termSubQuery.fieldName(), termQueryBuilder.fieldName()); + assertEquals(termSubQuery.value(), termQueryBuilder.value()); + } + + /** + * Tests query with boost: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "term": { + * "text": "keyword" + * } + * }, + * { + * "term": { + * "text": "keyword" + * } + * } + * ], + * "boost" : 2.0 + * } + * } + * } + */ + @SneakyThrows + public void testBoost_whenNonDefaultBoostSet_thenFail() { + XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder() + .startObject() + .startArray("queries") + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .endArray() + .field("boost", 2.0f) + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilderWithNonDefaultBoost.contentType().xContent(), + BytesReference.bytes(xContentBuilderWithNonDefaultBoost) + ); + contentParser.nextToken(); + + ParsingException exception = expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser)); + assertThat(exception.getMessage(), containsString("query does not support [boost]")); + } + + @SneakyThrows + public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { + // create query with 6 sub-queries, which is more than current max allowed + XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder() + .startObject() + .startArray("queries") + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .startObject() + .startObject("term") + .field(TEXT_FIELD_NAME, RandomizedTest.randomAsciiAlphanumOfLength(10)) + .endObject() + .endObject() + .endArray() + .field("boost", DEFAULT_BOOST) + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilderWithNonDefaultBoost.contentType().xContent(), + BytesReference.bytes(xContentBuilderWithNonDefaultBoost) + ); + contentParser.nextToken(); + + HybridQueryBuilder hybridQueryBuilder = HybridQueryBuilder.fromXContent(contentParser); + assertNotNull(hybridQueryBuilder); + } + + private Map getInnerMap(Object innerObject, String queryName, String fieldName) { + if (!(innerObject instanceof Map)) { + fail("field name does not map to nested object"); + } + Map secondInnerMap = (Map) innerObject; + assertTrue(secondInnerMap.containsKey(queryName)); + assertTrue(secondInnerMap.get(queryName) instanceof Map); + Map neuralInnerMap = (Map) secondInnerMap.get(queryName); + assertTrue(neuralInnerMap.containsKey(fieldName)); + assertTrue(neuralInnerMap.get(fieldName) instanceof Map); + Map vectorFieldInnerMap = (Map) neuralInnerMap.get(fieldName); + return vectorFieldInnerMap; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java new file mode 100644 index 000000000..59c90f495 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -0,0 +1,410 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; + +import com.google.common.primitives.Floats; + +public class HybridQueryIT extends BaseNeuralSearchIT { + private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index"; + private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index"; + private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; + private static final int MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX = 3; + private static final String TEST_QUERY_TEXT = "greetings"; + private static final String TEST_QUERY_TEXT2 = "salute"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT5 = "welcome"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final static String RELATION_EQUAL_TO = "eq"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + findDeployedModels().forEach(this::deleteModel); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + /** + * Tests basic query, example of query structure: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 1 + * } + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testBasicQuery_whenOneSubQuery_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + String modelId = getDeployedModelId(); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + + Map searchResponseAsMap1 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilder, 10); + + assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap1)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap1).isPresent()); + } + + /** + * Tests complex query with multiple nested sub-queries: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "bool": { + * "should": [ + * { + * "term": { + * "text": "word1" + * } + * }, + * { + * "term": { + * "text": "word2" + * } + * } + * ] + * } + * }, + * { + * "term": { + * "text": "word3" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map searchResponseAsMap1 = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10); + + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + /** + * Using queries similar to below to test sub-queries order: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 1 + * } + * } + * }, + * { + * "term": { + * "text": "word" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + String modelId = getDeployedModelId(); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder); + + Map searchResponseAsMap1 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10); + + assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap1)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap1); + List ids1 = new ArrayList<>(); + List scores1 = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids1.add((String) oneHit.get("_id")); + scores1.add((Double) oneHit.get("_score")); + } + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores1.size() - 1).noneMatch(idx -> scores1.get(idx) < scores1.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids1).size(), ids1.size()); + + // check similar query when sub-queries are in reverse order, results must be same as in previous test case + HybridQueryBuilder hybridQueryBuilderTermThenNeural = new HybridQueryBuilder(); + hybridQueryBuilderTermThenNeural.add(termQueryBuilder); + hybridQueryBuilderTermThenNeural.add(neuralQueryBuilder); + + Map searchResponseAsMap2 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10); + + assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap2)); + + List ids2 = new ArrayList<>(); + List scores2 = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids2.add((String) oneHit.get("_id")); + scores2.add((Double) oneHit.get("_score")); + } + + Map total2 = getTotalHits(searchResponseAsMap2); + assertNotNull(total.get("value")); + assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total2.get("value")); + assertNotNull(total2.get("relation")); + assertEquals(RELATION_EQUAL_TO, total2.get("relation")); + // doc ids must match same from the previous query, order of sub-queries doesn't change the result + assertEquals(ids1, ids2); + assertEquals(scores1, scores2); + } + + @SneakyThrows + public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + TermQueryBuilder termQuery2Builder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(termQueryBuilder); + hybridQueryBuilderOnlyTerm.add(termQuery2Builder); + + Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderOnlyTerm, 10); + + assertEquals(0, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isEmpty()); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(0, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { + prepareKnnIndex( + TEST_BASIC_INDEX_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) + ); + addKnnDoc( + TEST_BASIC_INDEX_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()) + ); + assertEquals(1, getDocCount(TEST_BASIC_INDEX_NAME)); + } + if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) { + prepareKnnIndex( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + List.of( + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE), + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE) + ) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "1", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector1).toArray(), Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "2", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector2).toArray(), Floats.asList(testVector2).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "3", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector3).toArray(), Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + assertEquals(3, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME)); + } + } + + private List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + private Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map) hitsMap.get("total"); + } + + private Optional getMaxScore(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java new file mode 100644 index 000000000..62ddb64f6 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -0,0 +1,209 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; + +import lombok.SneakyThrows; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Weight; +import org.apache.lucene.tests.util.TestUtil; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class HybridQueryScorerTests extends OpenSearchQueryTestCase { + + @SneakyThrows + public void testWithRandomDocuments_whenOneSubScorer_thenReturnSuccessfully() { + int maxDocId = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores = generateDocuments(maxDocId); + int[] docs = docsAndScores.getLeft(); + float[] scores = docsAndScores.getRight(); + + Weight weight = mock(Weight.class); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList(scorer(docsAndScores.getLeft(), docsAndScores.getRight(), fakeWeight(new MatchAllDocsQuery()))) + ); + + testWithQuery(docs, scores, hybridQueryScorer); + } + + @SneakyThrows + public void testWithRandomDocumentsAndHybridScores_whenMultipleScorers_thenReturnSuccessfully() { + int maxDocId1 = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores1 = generateDocuments(maxDocId1); + int[] docs1 = docsAndScores1.getLeft(); + float[] scores1 = docsAndScores1.getRight(); + int maxDocId2 = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores2 = generateDocuments(maxDocId2); + int[] docs2 = docsAndScores2.getLeft(); + float[] scores2 = docsAndScores2.getRight(); + + Weight weight = mock(Weight.class); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())), + scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery())) + ) + ); + int doc = -1; + int numOfActualDocs = 0; + Set uniqueDocs1 = Arrays.stream(docs1).boxed().collect(Collectors.toSet()); + Set uniqueDocs2 = Arrays.stream(docs2).boxed().collect(Collectors.toSet()); + while (doc != NO_MORE_DOCS) { + doc = hybridQueryScorer.iterator().nextDoc(); + if (doc == DocIdSetIterator.NO_MORE_DOCS) { + continue; + } + float[] actualTotalScores = hybridQueryScorer.hybridScores(); + float actualTotalScore = 0.0f; + for (float score : actualTotalScores) { + actualTotalScore += score; + } + float expectedScore = 0.0f; + if (uniqueDocs1.contains(doc)) { + int idx = Arrays.binarySearch(docs1, doc); + expectedScore += scores1[idx]; + } + if (uniqueDocs2.contains(doc)) { + int idx = Arrays.binarySearch(docs2, doc); + expectedScore += scores2[idx]; + } + assertEquals(expectedScore, actualTotalScore, 0.001f); + numOfActualDocs++; + } + + int totalUniqueCount = uniqueDocs1.size(); + for (int n : uniqueDocs2) { + if (!uniqueDocs1.contains(n)) { + totalUniqueCount++; + } + } + assertEquals(totalUniqueCount, numOfActualDocs); + } + + @SneakyThrows + public void testWithRandomDocumentsAndCombinedScore_whenMultipleScorers_thenReturnSuccessfully() { + int maxDocId1 = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores1 = generateDocuments(maxDocId1); + int[] docs1 = docsAndScores1.getLeft(); + float[] scores1 = docsAndScores1.getRight(); + int maxDocId2 = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores2 = generateDocuments(maxDocId2); + int[] docs2 = docsAndScores2.getLeft(); + float[] scores2 = docsAndScores2.getRight(); + + Weight weight = mock(Weight.class); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())), + scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery())) + ) + ); + int doc = -1; + int numOfActualDocs = 0; + Set uniqueDocs1 = Arrays.stream(docs1).boxed().collect(Collectors.toSet()); + Set uniqueDocs2 = Arrays.stream(docs2).boxed().collect(Collectors.toSet()); + while (doc != NO_MORE_DOCS) { + doc = hybridQueryScorer.iterator().nextDoc(); + if (doc == DocIdSetIterator.NO_MORE_DOCS) { + continue; + } + float expectedScore = 0.0f; + if (uniqueDocs1.contains(doc)) { + int idx = Arrays.binarySearch(docs1, doc); + expectedScore += scores1[idx]; + } + if (uniqueDocs2.contains(doc)) { + int idx = Arrays.binarySearch(docs2, doc); + expectedScore += scores2[idx]; + } + assertEquals(expectedScore, hybridQueryScorer.score(), 0.001f); + numOfActualDocs++; + } + + int totalUniqueCount = uniqueDocs1.size(); + for (int n : uniqueDocs2) { + if (!uniqueDocs1.contains(n)) { + totalUniqueCount++; + } + } + assertEquals(totalUniqueCount, numOfActualDocs); + } + + @SneakyThrows + public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenReturnSuccessfully() { + int maxDocId = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores = generateDocuments(maxDocId); + int[] docs = docsAndScores.getLeft(); + float[] scores = docsAndScores.getRight(); + + Weight weight = mock(Weight.class); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList(null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null) + ); + + testWithQuery(docs, scores, hybridQueryScorer); + } + + private Pair generateDocuments(int maxDocId) { + final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2); + final int[] docs = new int[numDocs]; + final Set uniqueDocs = new HashSet<>(); + while (uniqueDocs.size() < numDocs) { + uniqueDocs.add(random().nextInt(maxDocId)); + } + int i = 0; + for (int doc : uniqueDocs) { + docs[i++] = doc; + } + Arrays.sort(docs); + final float[] scores = new float[numDocs]; + for (int j = 0; j < numDocs; ++j) { + scores[j] = random().nextFloat(); + } + return new ImmutablePair<>(docs, scores); + } + + private void testWithQuery(int[] docs, float[] scores, HybridQueryScorer hybridQueryScorer) throws IOException { + int doc = -1; + int numOfActualDocs = 0; + while (doc != NO_MORE_DOCS) { + int target = doc + 1; + doc = hybridQueryScorer.iterator().nextDoc(); + int idx = Arrays.binarySearch(docs, target); + idx = (idx >= 0) ? idx : (-1 - idx); + if (idx == docs.length) { + assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); + } else { + assertEquals(docs[idx], doc); + assertEquals(scores[idx], hybridQueryScorer.score(), 0f); + numOfActualDocs++; + } + } + assertEquals(docs.length, numOfActualDocs); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java new file mode 100644 index 000000000..4d24c2049 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -0,0 +1,278 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.query.HybridQueryBuilderTests.QUERY_TEXT; +import static org.opensearch.neuralsearch.query.HybridQueryBuilderTests.TEXT_FIELD_NAME; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +import lombok.SneakyThrows; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.tests.search.QueryUtils; +import org.opensearch.core.index.Index; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.KNNQueryBuilder; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class HybridQueryTests extends OpenSearchQueryTestCase { + + static final String VECTOR_FIELD_NAME = "vectorField"; + static final String TERM_QUERY_TEXT = "keyword"; + static final String TERM_ANOTHER_QUERY_TEXT = "anotherkeyword"; + static final float[] VECTOR_QUERY = new float[] { 1.0f, 2.0f, 2.1f, 0.6f }; + static final int K = 2; + + @SneakyThrows + public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery query1 = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + ); + HybridQuery query2 = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + ); + HybridQuery query3 = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) + ) + ); + QueryUtils.check(query1); + QueryUtils.checkEqual(query1, query2); + QueryUtils.checkUnequal(query1, query3); + + Iterator queryIterator = query3.iterator(); + assertNotNull(queryIterator); + int countOfQueries = 0; + while (queryIterator.hasNext()) { + Query query = queryIterator.next(); + assertNotNull(query); + countOfQueries++; + } + assertEquals(2, countOfQueries); + } + + @SneakyThrows + public void testRewrite_whenRewriteQuery_thenSuccessful() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), RandomizedTest.randomAsciiAlphanumOfLength(8), ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), RandomizedTest.randomAsciiAlphanumOfLength(8), ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), RandomizedTest.randomAsciiAlphanumOfLength(8), ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + ); + Query rewritten = hybridQueryWithTerm.rewrite(reader); + // term query is the same after we rewrite it + assertSame(hybridQueryWithTerm, rewritten); + + Index dummyIndex = new Index("dummy", "dummy"); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K); + Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext); + + HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery)); + rewritten = hybridQueryWithKnn.rewrite(reader); + assertSame(hybridQueryWithKnn, rewritten); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenReturnSuccessfully() { + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + String field1Value = "text1"; + String field2Value = "text2"; + String field3Value = "text3"; + + final Directory dir = newDirectory(); + final IndexWriter w = new IndexWriter(dir, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, field1Value, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, field2Value, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, field3Value, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + HybridQuery query = new HybridQuery( + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))) + ); + // executing search query, getting up to 3 docs in result + TopDocs hybridQueryResult = searcher.search(query, 3); + + assertNotNull(hybridQueryResult); + assertEquals(2, hybridQueryResult.scoreDocs.length); + List expectedDocIds = List.of(docId1, docId2); + List actualDocIds = Arrays.stream(hybridQueryResult.scoreDocs).map(scoreDoc -> { + try { + return reader.document(scoreDoc.doc).getField("id").stringValue(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).map(Integer::parseInt).collect(Collectors.toList()); + assertEquals(actualDocIds, expectedDocIds); + assertFalse(actualDocIds.contains(docId3)); + w.close(); + reader.close(); + dir.close(); + } + + @SneakyThrows + public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSuccessfully() { + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + String field1Value = "text1"; + String field2Value = "text2"; + String field3Value = "text3"; + + final Directory dir = newDirectory(); + final IndexWriter w = new IndexWriter(dir, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, field1Value, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, field2Value, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, field3Value, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)))); + // executing search query, getting up to 3 docs in result + TopDocs hybridQueryResult = searcher.search(query, 3); + + assertNotNull(hybridQueryResult); + assertEquals(0, hybridQueryResult.scoreDocs.length); + w.close(); + reader.close(); + dir.close(); + } + + @SneakyThrows + public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenReturnSuccessfully() { + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + String field1Value = "text1"; + String field2Value = "text2"; + String field3Value = "text3"; + + final Directory dir = newDirectory(); + final IndexWriter w = new IndexWriter(dir, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, field1Value, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, field2Value, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, field3Value, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + HybridQuery query = new HybridQuery( + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))) + ); + // executing search query, getting up to 3 docs in result + TopDocs hybridQueryResult = searcher.search(query, 3); + + assertNotNull(hybridQueryResult); + assertEquals(0, hybridQueryResult.scoreDocs.length); + w.close(); + reader.close(); + dir.close(); + } + + @SneakyThrows + public void testWithRandomDocuments_whenNoSubQueries_thenFail() { + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); + } + + @SneakyThrows + public void testToString_whenCallQueryToString_thenSuccessful() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery query = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext), + new BoolQueryBuilder().should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) + .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) + .toQuery(mockQueryShardContext) + ) + ); + + String queryString = query.toString(TEXT_FIELD_NAME); + assertEquals("(keyword | anotherkeyword | (keyword anotherkeyword))", queryString); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java new file mode 100644 index 000000000..c876621a2 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.query.HybridQueryBuilderTests.TEXT_FIELD_NAME; + +import java.util.List; + +import lombok.SneakyThrows; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Matches; +import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class HybridQueryWeightTests extends OpenSearchQueryTestCase { + + static final String TERM_QUERY_TEXT = "keyword"; + + @SneakyThrows + public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId, TERM_QUERY_TEXT, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + ); + IndexSearcher searcher = newSearcher(reader); + Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); + + assertNotNull(weight); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + Scorer scorer = weight.scorer(leafReaderContext); + + assertNotNull(scorer); + + DocIdSetIterator iterator = scorer.iterator(); + int actualDoc = iterator.nextDoc(); + int actualDocId = Integer.parseInt(reader.document(actualDoc).getField("id").stringValue()); + + assertEquals(docId, actualDocId); + + assertTrue(weight.isCacheable(leafReaderContext)); + + Matches matches = weight.matches(leafReaderContext, actualDoc); + MatchesIterator matchesIterator = matches.getMatches(TEXT_FIELD_NAME); + assertTrue(matchesIterator.next()); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testExplain_whenCallExplain_thenFail() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId, RandomizedTest.randomAsciiAlphanumOfLength(8), ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + ); + IndexSearcher searcher = newSearcher(reader); + Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); + + assertNotNull(weight); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + expectThrows(UnsupportedOperationException.class, () -> weight.explain(leafReaderContext, docId)); + + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index 5c7ec335c..c55da9e7f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -12,10 +12,10 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; import lombok.SneakyThrows; +import org.junit.After; import org.junit.Before; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -39,13 +39,24 @@ public class NeuralQueryIT extends BaseNeuralSearchIT { private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; - private static final AtomicReference modelId = new AtomicReference<>(); private final float[] testVector = createRandomVector(TEST_DIMENSION); @Before public void setUp() throws Exception { super.setUp(); - modelId.compareAndSet(modelId.get(), prepareModel()); + updateClusterSettings(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + findDeployedModels().forEach(this::deleteModel); } /** @@ -65,10 +76,11 @@ public void setUp() throws Exception { @SneakyThrows public void testBasicQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -77,7 +89,7 @@ public void testBasicQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -99,10 +111,11 @@ public void testBasicQuery() { @SneakyThrows public void testBoostQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -114,7 +127,7 @@ public void testBoostQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = 2 * computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -141,11 +154,12 @@ public void testBoostQuery() { @SneakyThrows public void testRescoreQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); NeuralQueryBuilder rescoreNeuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -155,7 +169,7 @@ public void testRescoreQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -187,12 +201,13 @@ public void testRescoreQuery() { @SneakyThrows public void testBooleanQuery_withMultipleNeuralQueries() { initializeIndexIfNotExist(TEST_MULTI_VECTOR_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -200,7 +215,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_2, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -212,7 +227,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = 2 * computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -242,12 +257,13 @@ public void testBooleanQuery_withMultipleNeuralQueries() { @SneakyThrows public void testBooleanQuery_withNeuralAndBM25Queries() { initializeIndexIfNotExist(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -261,7 +277,7 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float minExpectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float minExpectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertTrue(minExpectedScore < objectToFloat(firstInnerHit.get("_score"))); } @@ -286,11 +302,12 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { @SneakyThrows public void testNestedQuery() { initializeIndexIfNotExist(TEST_NESTED_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_NESTED, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -300,7 +317,7 @@ public void testNestedQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -328,10 +345,11 @@ public void testNestedQuery() { @SneakyThrows public void testFilterQuery() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, new MatchQueryBuilder("_id", "3") @@ -340,7 +358,7 @@ public void testFilterQuery() { assertEquals(1, getHitCount(searchResponseAsMap)); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("3", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java new file mode 100644 index 000000000..94866acb8 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -0,0 +1,232 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static java.util.stream.Collectors.toList; +import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.compress.CompressedXContent; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.analysis.AnalyzerScope; +import org.opensearch.index.analysis.IndexAnalyzers; +import org.opensearch.index.analysis.NamedAnalyzer; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.similarity.SimilarityService; +import org.opensearch.indices.IndicesModule; +import org.opensearch.indices.mapper.MapperRegistry; +import org.opensearch.plugins.MapperPlugin; +import org.opensearch.plugins.ScriptPlugin; +import org.opensearch.script.ScriptModule; +import org.opensearch.script.ScriptService; +import org.opensearch.test.OpenSearchTestCase; + +public abstract class OpenSearchQueryTestCase extends OpenSearchTestCase { + + protected final MapperService createMapperService(Version version, XContentBuilder mapping) throws IOException { + IndexMetadata meta = IndexMetadata.builder("index") + .settings(Settings.builder().put("index.version.created", version)) + .numberOfReplicas(0) + .numberOfShards(1) + .build(); + IndexSettings indexSettings = new IndexSettings(meta, getIndexSettings()); + MapperRegistry mapperRegistry = new IndicesModule( + Stream.of().filter(p -> p instanceof MapperPlugin).map(p -> (MapperPlugin) p).collect(toList()) + ).getMapperRegistry(); + ScriptModule scriptModule = new ScriptModule( + Settings.EMPTY, + Stream.of().filter(p -> p instanceof ScriptPlugin).map(p -> (ScriptPlugin) p).collect(toList()) + ); + ScriptService scriptService = new ScriptService(getIndexSettings(), scriptModule.engines, scriptModule.contexts); + SimilarityService similarityService = new SimilarityService(indexSettings, scriptService, emptyMap()); + MapperService mapperService = new MapperService( + indexSettings, + createIndexAnalyzers(indexSettings), + xContentRegistry(), + similarityService, + mapperRegistry, + () -> { throw new UnsupportedOperationException(); }, + () -> true, + scriptService + ); + merge(mapperService, mapping); + return mapperService; + } + + protected Settings getIndexSettings() { + return Settings.builder().put("index.version.created", Version.CURRENT).build(); + } + + protected IndexAnalyzers createIndexAnalyzers(IndexSettings indexSettings) { + return new IndexAnalyzers( + singletonMap("default", new NamedAnalyzer("default", AnalyzerScope.INDEX, new StandardAnalyzer())), + emptyMap(), + emptyMap() + ); + } + + protected final void merge(MapperService mapperService, XContentBuilder mapping) throws IOException { + mapperService.merge("_doc", new CompressedXContent(BytesReference.bytes(mapping)), MapperService.MergeReason.MAPPING_UPDATE); + } + + protected final XContentBuilder fieldMapping(CheckedConsumer buildField) throws IOException { + return mapping(b -> { + b.startObject("field"); + buildField.accept(b); + b.endObject(); + }); + } + + protected final XContentBuilder mapping(CheckedConsumer buildFields) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("_doc").startObject("properties"); + buildFields.accept(builder); + return builder.endObject().endObject().endObject(); + } + + protected MapperService createMapperService(XContentBuilder mappings) throws IOException { + return createMapperService(Version.CURRENT, mappings); + } + + protected MapperService createMapperService() throws IOException { + return createMapperService( + fieldMapping( + b -> b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject() + ) + ); + } + + protected static Document getDocument(String fieldName, int docId, String fieldValue, FieldType ft) { + Document doc = new Document(); + doc.add(new TextField("id", Integer.toString(docId), Field.Store.YES)); + doc.add(new Field(fieldName, fieldValue, ft)); + return doc; + } + + protected static Weight fakeWeight(Query query) { + return new Weight(query) { + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return null; + } + + @Override + public Scorer scorer(LeafReaderContext context) { + return null; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } + + static DocIdSetIterator iterator(final int... docs) { + return new DocIdSetIterator() { + + int i = -1; + + @Override + public int nextDoc() { + if (i + 1 == docs.length) { + return NO_MORE_DOCS; + } else { + return docs[++i]; + } + } + + @Override + public int docID() { + return i < 0 ? -1 : i == docs.length ? NO_MORE_DOCS : docs[i]; + } + + @Override + public long cost() { + return docs.length; + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + }; + } + + protected static Scorer scorer(final int[] docs, List scores, Weight weight) { + float[] scoresAsArray = new float[scores.size()]; + int i = 0; + for (float score : scores) { + scoresAsArray[i++] = score; + } + return scorer(docs, scoresAsArray, weight); + } + + protected static Scorer scorer(final int[] docs, final float[] scores, Weight weight) { + final DocIdSetIterator iterator = iterator(docs); + return new Scorer(weight) { + + int lastScoredDoc = -1; + + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public float score() { + assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID()); + lastScoredDoc = docID(); + final int idx = Arrays.binarySearch(docs, docID()); + return scores[idx]; + } + + @Override + public float getMaxScore(int upTo) { + return Float.MAX_VALUE; + } + }; + } + + @SuppressForbidden(reason = "manipulates system properties for testing") + protected static void initFeatureFlags() { + System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey(), "true"); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java new file mode 100644 index 000000000..0c79d7f73 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search; + +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.lang3.RandomUtils; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class CompoundTopDocsTests extends OpenSearchQueryTestCase { + + public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() { + TopDocs topDocs1 = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, RandomUtils.nextFloat()), new ScoreDoc(1, RandomUtils.nextFloat()) } + ); + TopDocs topDocs2 = new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(2, RandomUtils.nextFloat()), + new ScoreDoc(4, RandomUtils.nextFloat()), + new ScoreDoc(5, RandomUtils.nextFloat()) } + ); + List topDocs = List.of(topDocs1, topDocs2); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs); + assertNotNull(compoundTopDocs); + assertEquals(topDocs, compoundTopDocs.getCompoundTopDocs()); + } + + public void testBasics_whenCreateWithoutTopDocs_thenTopDocsIsNull() { + CompoundTopDocs hybridQueryScoreTopDocs = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(2, RandomUtils.nextFloat()), + new ScoreDoc(4, RandomUtils.nextFloat()), + new ScoreDoc(5, RandomUtils.nextFloat()) } + ); + assertNotNull(hybridQueryScoreTopDocs); + assertNull(hybridQueryScoreTopDocs.getCompoundTopDocs()); + } + + public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWithMostHits() { + TopDocs topDocs1 = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), null); + TopDocs topDocs2 = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, RandomUtils.nextFloat()), new ScoreDoc(4, RandomUtils.nextFloat()) } + ); + List topDocs = List.of(topDocs1, topDocs2); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs); + assertNotNull(compoundTopDocs); + assertNotNull(compoundTopDocs.scoreDocs); + assertEquals(2, compoundTopDocs.scoreDocs.length); + } + + public void testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() { + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (List) null); + assertNotNull(compoundTopDocs); + assertNull(compoundTopDocs.scoreDocs); + + CompoundTopDocs compoundTopDocsWithNullArray = new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + Arrays.asList(null, null) + ); + assertNotNull(compoundTopDocsWithNullArray); + assertNotNull(compoundTopDocsWithNullArray.scoreDocs); + assertEquals(0, compoundTopDocsWithNullArray.scoreDocs.length); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java b/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java new file mode 100644 index 000000000..0a6a12c88 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search; + +import java.util.stream.IntStream; + +import org.apache.lucene.search.ScoreMode; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class HitsTresholdCheckerTests extends OpenSearchQueryTestCase { + + public void testTresholdReached_whenIncrementCount_thenTresholdReached() { + HitsThresholdChecker hitsThresholdChecker = new HitsThresholdChecker(5); + assertEquals(5, hitsThresholdChecker.getTotalHitsThreshold()); + assertEquals(ScoreMode.TOP_SCORES, hitsThresholdChecker.scoreMode()); + assertFalse(hitsThresholdChecker.isThresholdReached()); + hitsThresholdChecker.incrementHitCount(); + assertFalse(hitsThresholdChecker.isThresholdReached()); + IntStream.rangeClosed(1, 5).forEach((checker) -> hitsThresholdChecker.incrementHitCount()); + assertTrue(hitsThresholdChecker.isThresholdReached()); + } + + public void testTresholdLimit_whenThresholdNegative_thenFail() { + expectThrows(IllegalArgumentException.class, () -> new HitsThresholdChecker(-1)); + } + + public void testTresholdLimit_whenThresholdMaxValue_thenFail() { + expectThrows(IllegalArgumentException.class, () -> new HitsThresholdChecker(Integer.MAX_VALUE)); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java new file mode 100644 index 000000000..747b05992 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java @@ -0,0 +1,352 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import lombok.SneakyThrows; + +import org.apache.commons.lang3.RandomUtils; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.neuralsearch.query.HybridQueryScorer; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase { + + static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_QUERY_TEXT = "greeting"; + private static final String TEST_QUERY_TEXT2 = "salute"; + private static final int NUM_DOCS = 4; + private static final int TOTAL_HITS_UP_TO = 1000; + + private static final int DOC_ID_1 = RandomUtils.nextInt(0, 100_000); + private static final int DOC_ID_2 = RandomUtils.nextInt(0, 100_000); + private static final int DOC_ID_3 = RandomUtils.nextInt(0, 100_000); + private static final int DOC_ID_4 = RandomUtils.nextInt(0, 100_000); + private static final String FIELD_1_VALUE = "text1"; + private static final String FIELD_2_VALUE = "text2"; + private static final String FIELD_3_VALUE = "text3"; + private static final String FIELD_4_VALUE = "text4"; + + @SneakyThrows + public void testBasics_whenCreateNewCollector_thenSuccessful() { + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(TOTAL_HITS_UP_TO) + ); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + assertNotNull(leafCollector); + + assertEquals(ScoreMode.TOP_SCORES, hybridTopScoreDocCollector.scoreMode()); + + Weight weight = mock(Weight.class); + hybridTopScoreDocCollector.setWeight(weight); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testGetHybridScores_whenCreateNewAndGetScores_thenSuccessful() { + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(TOTAL_HITS_UP_TO) + ); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + int[] docIds = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3 }; + Arrays.sort(docIds); + final List scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList()); + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList(scorer(docIds, scores, fakeWeight(new MatchAllDocsQuery()))) + ); + + leafCollector.setScorer(hybridQueryScorer); + List hybridScores = new ArrayList<>(); + DocIdSetIterator iterator = hybridQueryScorer.iterator(); + int nextDoc = iterator.nextDoc(); + while (nextDoc != NO_MORE_DOCS) { + hybridScores.add(hybridQueryScorer.hybridScores()); + nextDoc = iterator.nextDoc(); + } + // assert + assertEquals(3, hybridScores.size()); + assertFalse(hybridScores.stream().anyMatch(score -> score[0] <= 0.0)); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testTopDocs_whenCreateNewAndGetTopDocs_thenSuccessful() { + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(TOTAL_HITS_UP_TO) + ); + Weight weight = mock(Weight.class); + hybridTopScoreDocCollector.setWeight(weight); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + assertNotNull(leafCollector); + + int[] docIds = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3 }; + Arrays.sort(docIds); + final List scores1 = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList()); + final List scores2 = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList()); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorer( + docIds, + scores1, + fakeWeight(QueryBuilders.termQuery(TEXT_FIELD_NAME, TEST_QUERY_TEXT).toQuery(mockQueryShardContext)) + ), + scorer( + docIds, + scores2, + fakeWeight(QueryBuilders.termQuery(TEXT_FIELD_NAME, TEST_QUERY_TEXT2).toQuery(mockQueryShardContext)) + ) + ) + ); + + leafCollector.setScorer(hybridQueryScorer); + DocIdSetIterator iterator = hybridQueryScorer.iterator(); + + int doc = iterator.nextDoc(); + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + leafCollector.collect(doc); + doc = iterator.nextDoc(); + } + + List topDocs = hybridTopScoreDocCollector.topDocs(); + + assertNotNull(topDocs); + assertEquals(2, topDocs.size()); + + for (TopDocs topDoc : topDocs) { + // assert results for each sub-query, there must be correct number of matches, all doc id are correct and scores must be desc + // ordered + assertEquals(3, topDoc.totalHits.value); + ScoreDoc[] scoreDocs = topDoc.scoreDocs; + assertNotNull(scoreDocs); + assertEquals(3, scoreDocs.length); + assertTrue(IntStream.range(0, scoreDocs.length - 1).noneMatch(idx -> scoreDocs[idx].score < scoreDocs[idx + 1].score)); + List resultDocIds = Arrays.stream(scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()); + assertTrue(Arrays.stream(docIds).allMatch(resultDocIds::contains)); + } + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testTopDocs_whenMatchedDocsDifferentForEachSubQuery_thenSuccessful() { + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_4, FIELD_4_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(TOTAL_HITS_UP_TO) + ); + Weight weight = mock(Weight.class); + hybridTopScoreDocCollector.setWeight(weight); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + assertNotNull(leafCollector); + + int[] docIdsQuery1 = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3 }; + int[] docIdsQuery2 = new int[] { DOC_ID_4, DOC_ID_1 }; + int[] docIdsQueryMatchAll = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3, DOC_ID_4 }; + Arrays.sort(docIdsQuery1); + Arrays.sort(docIdsQuery2); + Arrays.sort(docIdsQueryMatchAll); + final List scores1 = Stream.generate(() -> random().nextFloat()).limit(docIdsQuery1.length).collect(Collectors.toList()); + final List scores2 = Stream.generate(() -> random().nextFloat()).limit(docIdsQuery2.length).collect(Collectors.toList()); + final List scoresMatchAll = Stream.generate(() -> random().nextFloat()) + .limit(docIdsQueryMatchAll.length) + .collect(Collectors.toList()); + + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorer( + docIdsQuery1, + scores1, + fakeWeight(QueryBuilders.termQuery(TEXT_FIELD_NAME, TEST_QUERY_TEXT).toQuery(mockQueryShardContext)) + ), + scorer( + docIdsQuery2, + scores2, + fakeWeight(QueryBuilders.termQuery(TEXT_FIELD_NAME, TEST_QUERY_TEXT2).toQuery(mockQueryShardContext)) + ), + scorer(new int[0], new float[0], fakeWeight(new MatchNoDocsQuery())), + scorer(docIdsQueryMatchAll, scoresMatchAll, fakeWeight(new MatchAllDocsQuery())) + ) + ); + + leafCollector.setScorer(hybridQueryScorer); + DocIdSetIterator iterator = hybridQueryScorer.iterator(); + + int doc = iterator.nextDoc(); + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + leafCollector.collect(doc); + doc = iterator.nextDoc(); + } + + List topDocs = hybridTopScoreDocCollector.topDocs(); + + assertNotNull(topDocs); + assertEquals(4, topDocs.size()); + + // assert result for each sub query + // term query 1 + TopDocs topDocQuery1 = topDocs.get(0); + assertEquals(docIdsQuery1.length, topDocQuery1.totalHits.value); + ScoreDoc[] scoreDocsQuery1 = topDocQuery1.scoreDocs; + assertNotNull(scoreDocsQuery1); + assertEquals(docIdsQuery1.length, scoreDocsQuery1.length); + assertTrue( + IntStream.range(0, scoreDocsQuery1.length - 1).noneMatch(idx -> scoreDocsQuery1[idx].score < scoreDocsQuery1[idx + 1].score) + ); + List resultDocIdsQuery1 = Arrays.stream(scoreDocsQuery1).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()); + assertTrue(Arrays.stream(docIdsQuery1).allMatch(resultDocIdsQuery1::contains)); + // term query 2 + TopDocs topDocQuery2 = topDocs.get(1); + assertEquals(docIdsQuery2.length, topDocQuery2.totalHits.value); + ScoreDoc[] scoreDocsQuery2 = topDocQuery2.scoreDocs; + assertNotNull(scoreDocsQuery2); + assertEquals(docIdsQuery2.length, scoreDocsQuery2.length); + assertTrue( + IntStream.range(0, scoreDocsQuery2.length - 1).noneMatch(idx -> scoreDocsQuery2[idx].score < scoreDocsQuery2[idx + 1].score) + ); + List resultDocIdsQuery2 = Arrays.stream(scoreDocsQuery2).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()); + assertTrue(Arrays.stream(docIdsQuery2).allMatch(resultDocIdsQuery2::contains)); + // no match query + TopDocs topDocQuery3 = topDocs.get(2); + assertEquals(0, topDocQuery3.totalHits.value); + ScoreDoc[] scoreDocsQuery3 = topDocQuery3.scoreDocs; + assertNotNull(scoreDocsQuery3); + assertEquals(0, scoreDocsQuery3.length); + // all match query + TopDocs topDocQueryMatchAll = topDocs.get(3); + assertEquals(docIdsQueryMatchAll.length, topDocQueryMatchAll.totalHits.value); + ScoreDoc[] scoreDocsQueryMatchAll = topDocQueryMatchAll.scoreDocs; + assertNotNull(scoreDocsQueryMatchAll); + assertEquals(docIdsQueryMatchAll.length, scoreDocsQueryMatchAll.length); + assertTrue( + IntStream.range(0, scoreDocsQueryMatchAll.length - 1) + .noneMatch(idx -> scoreDocsQueryMatchAll[idx].score < scoreDocsQueryMatchAll[idx + 1].score) + ); + List resultDocIdsQueryMatchAll = Arrays.stream(scoreDocsQueryMatchAll) + .map(scoreDoc -> scoreDoc.doc) + .collect(Collectors.toList()); + assertTrue(Arrays.stream(docIdsQueryMatchAll).allMatch(resultDocIdsQueryMatchAll::contains)); + + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java new file mode 100644 index 000000000..f63df42c9 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -0,0 +1,418 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search.query; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; + +import lombok.SneakyThrows; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { + private static final String VECTOR_FIELD_NAME = "vectorField"; + private static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "This is really nice place to be"; + private static final String QUERY_TEXT1 = "hello"; + private static final String QUERY_TEXT2 = "randomkeyword"; + private static final String QUERY_TEXT3 = "place"; + private static final Index dummyIndex = new Index("dummy", "dummy"); + private static final String MODEL_ID = "mfgfgdsfgfdgsde"; + private static final int K = 10; + private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); + + @SneakyThrows + public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + releaseResources(directory, w, reader); + + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + } + + @SneakyThrows + public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.queryResult()).thenReturn(new QuerySearchResult()); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + + Query query = termSubQuery.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + releaseResources(directory, w, reader); + + verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + } + + @SneakyThrows + public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(3); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertEquals(1, topDocs.totalHits.value); + assertTrue(topDocs instanceof CompoundTopDocs); + List compoundTopDocs = ((CompoundTopDocs) topDocs).getCompoundTopDocs(); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + TopDocs subQueryTopDocs = compoundTopDocs.get(0); + assertEquals(1, subQueryTopDocs.totalHits.value); + assertNotNull(subQueryTopDocs.scoreDocs); + assertEquals(1, subQueryTopDocs.scoreDocs.length); + ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[0]; + assertNotNull(scoreDoc); + int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); + assertEquals(docId1, actualDocId); + assertTrue(scoreDoc.score > 0.0f); + + releaseResources(directory, w, reader); + } + + @SneakyThrows + public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridResultsAreSet() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + int docId4 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.add(QueryBuilders.matchAllQuery()); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertEquals(4, topDocs.totalHits.value); + assertTrue(topDocs instanceof CompoundTopDocs); + List compoundTopDocs = ((CompoundTopDocs) topDocs).getCompoundTopDocs(); + assertNotNull(compoundTopDocs); + assertEquals(3, compoundTopDocs.size()); + + TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); + List expectedIds1 = List.of(docId1); + assertQueryResults(subQueryTopDocs1, expectedIds1, reader); + + TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); + List expectedIds2 = List.of(); + assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + + TopDocs subQueryTopDocs3 = compoundTopDocs.get(2); + List expectedIds3 = List.of(docId1, docId2, docId3, docId4); + assertQueryResults(subQueryTopDocs3, expectedIds3, reader); + + releaseResources(directory, w, reader); + } + + @SneakyThrows + private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { + assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); + assertNotNull(subQueryTopDocs.scoreDocs); + assertEquals(expectedDocIds.size(), subQueryTopDocs.scoreDocs.length); + assertEquals(TotalHits.Relation.EQUAL_TO, subQueryTopDocs.totalHits.relation); + for (int i = 0; i < expectedDocIds.size(); i++) { + int expectedDocId = expectedDocIds.get(i); + ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[i]; + assertNotNull(scoreDoc); + int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); + assertEquals(expectedDocId, actualDocId); + assertTrue(scoreDoc.score > 0.0f); + } + } + + private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { + w.close(); + reader.close(); + directory.close(); + } +}