From 1f67b94ee53bfe0824d3191b3ca35afa71a7a9e1 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 1 Aug 2023 23:12:56 +0200 Subject: [PATCH] Add harmonic mean combination (#238) * Add harmonic mean combination Signed-off-by: Martin Gaievski --- ...ithmeticMeanScoreCombinationTechnique.java | 70 ++------- ...HarmonicMeanScoreCombinationTechnique.java | 87 +++--------- .../combination/ScoreCombinationFactory.java | 10 +- .../ScoreCombinationTechnique.java | 1 + .../combination/ScoreCombinationUtil.java | 80 +++++++++++ .../common/BaseNeuralSearchIT.java | 6 +- .../NormalizationProcessorTests.java | 4 +- .../ScoreNormalizationCombinationIT.java | 133 +++++++++++++----- ...ticMeanScoreCombinationTechniqueTests.java | 55 +++++--- .../BaseScoreCombinationTechniqueTests.java | 57 ++++++++ ...nicMeanScoreCombinationTechniqueTests.java | 59 ++++++++ .../ScoreCombinationFactoryTests.java | 41 ++++++ 12 files changed, 412 insertions(+), 191 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 4eb9564e6..57040d2a1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -6,12 +6,8 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Abstracts combination of scores based on arithmetic mean method @@ -23,20 +19,12 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombination private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; - public ArithmeticMeanScoreCombinationTechnique(final Map params) { - validateParams(params); - weights = getWeights(params); - } - - private List getWeights(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return List.of(); - } - // get weights, we don't need to check for instance as it's done during validation - return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() - .map(Double::floatValue) - .collect(Collectors.toUnmodifiableList()); + public ArithmeticMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); } /** @@ -48,57 +36,19 @@ private List getWeights(final Map params) { @Override public float combine(final float[] scores) { float combinedScore = 0.0f; - float weights = 0; + float sumOfWeights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { float score = scores[indexOfSubQuery]; if (score >= 0.0) { - float weight = getWeightForSubQuery(indexOfSubQuery); + float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); score = score * weight; combinedScore += score; - weights += weight; + sumOfWeights += weight; } } - if (weights == 0.0f) { + if (sumOfWeights == 0.0f) { return ZERO_SCORE; } - return combinedScore / weights; - } - - private void validateParams(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return; - } - // check if only supported params are passed - Optional optionalNotSupportedParam = params.keySet() - .stream() - .filter(paramName -> !SUPPORTED_PARAMS.contains(paramName)) - .findFirst(); - if (optionalNotSupportedParam.isPresent()) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "provided parameter for combination technique is not supported. supported parameters are [%s]", - SUPPORTED_PARAMS.stream().collect(Collectors.joining(",")) - ) - ); - } - - // check param types - if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { - if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) - ); - } - } - } - - /** - * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise - * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query - * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default - */ - private float getWeightForSubQuery(int indexOfSubQuery) { - return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; + return combinedScore / sumOfWeights; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 3fff2db2b..cb44e030a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -6,99 +6,46 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** - * Abstracts combination of scores based on arithmetic mean method + * Abstracts combination of scores based on harmonic mean method */ public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique { - public static final String TECHNIQUE_NAME = "arithmetic_mean"; + public static final String TECHNIQUE_NAME = "harmonic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; - public HarmonicMeanScoreCombinationTechnique(final Map params) { - validateParams(params); - weights = getWeights(params); - } - - private List getWeights(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return List.of(); - } - // get weights, we don't need to check for instance as it's done during validation - return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() - .map(Double::floatValue) - .collect(Collectors.toUnmodifiableList()); + public HarmonicMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); } /** - * Arithmetic mean method for combining scores. - * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN) + * Weighted harmonic mean method for combining scores. + * score = sum(weight_1 + .... + weight_n)/sum(weight_1/score_1 + ... + weight_n/score_n) * * Zero (0.0) scores are excluded from number of scores N */ @Override public float combine(final float[] scores) { - float combinedScore = 0.0f; - float weights = 0; + float sumOfWeights = 0; + float sumOfHarmonics = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { float score = scores[indexOfSubQuery]; - if (score >= 0.0) { - float weight = getWeightForSubQuery(indexOfSubQuery); - score = score * weight; - combinedScore += score; - weights += weight; + if (score <= 0) { + continue; } + float weightOfSubQuery = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); + sumOfWeights += weightOfSubQuery; + sumOfHarmonics += weightOfSubQuery / score; } - if (weights == 0.0f) { - return ZERO_SCORE; - } - return combinedScore / weights; - } - - private void validateParams(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return; - } - // check if only supported params are passed - Optional optionalNotSupportedParam = params.keySet() - .stream() - .filter(paramName -> !SUPPORTED_PARAMS.contains(paramName)) - .findFirst(); - if (optionalNotSupportedParam.isPresent()) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "provided parameter for combination technique is not supported. supported parameters are [%s]", - SUPPORTED_PARAMS.stream().collect(Collectors.joining(",")) - ) - ); - } - - // check param types - if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { - if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) - ); - } - } - } - - /** - * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise - * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query - * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default - */ - private float getWeightForSubQuery(int indexOfSubQuery) { - return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; + return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 2f4804eb1..d034ede16 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -13,12 +13,18 @@ * Abstracts creation of exact score combination method based on technique name */ public class ScoreCombinationFactory { + private static final ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); - public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of()); + public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique( + Map.of(), + scoreCombinationUtil + ); private final Map, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of( ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, - ArithmeticMeanScoreCombinationTechnique::new + params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil), + HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil) ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java index 21090b1ce..6e0a5db65 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.processor.combination; public interface ScoreCombinationTechnique { + /** * Defines combination function specific to this technique * @param scores array of collected original scores diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java new file mode 100644 index 000000000..35e097f7f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Collection of utility methods for score combination technique classes + */ +class ScoreCombinationUtil { + private static final String PARAM_NAME_WEIGHTS = "weights"; + + /** + * Get collection of weights based on user provided config + * @param params map of named parameters and their values + * @return collection of weights + */ + public List getWeights(final Map params) { + if (Objects.isNull(params) || params.isEmpty()) { + return List.of(); + } + // get weights, we don't need to check for instance as it's done during validation + return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() + .map(Double::floatValue) + .collect(Collectors.toUnmodifiableList()); + } + + /** + * Validate config parameters for this technique + * @param actualParams map of parameters in form of name-value + * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique + */ + public void validateParams(final Map actualParams, final Set supportedParams) { + if (Objects.isNull(actualParams) || actualParams.isEmpty()) { + return; + } + // check if only supported params are passed + Optional optionalNotSupportedParam = actualParams.keySet() + .stream() + .filter(paramName -> !supportedParams.contains(paramName)) + .findFirst(); + if (optionalNotSupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "provided parameter for combination technique is not supported. supported parameters are [%s]", + supportedParams.stream().collect(Collectors.joining(",")) + ) + ); + } + + // check param types + if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) + ); + } + } + } + + /** + * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise + * @param weights collection of weights for sub-queries + * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query + * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default + */ + public float getWeightForSubQuery(final List weights, final int indexOfSubQuery) { + return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 6aa9b5a5a..add0205b0 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -58,8 +58,8 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; private static final String DEFAULT_USER_AGENT = "Kibana"; - protected static final String NORMALIZATION_METHOD = "min_max"; - protected static final String COMBINATION_METHOD = "arithmetic_mean"; + protected static final String DEFAULT_NORMALIZATION_METHOD = "min_max"; + protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; protected static final String PARAM_NAME_WEIGHTS = "weights"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -556,7 +556,7 @@ public boolean isUpdateClusterSettings() { @SneakyThrows protected void createSearchPipelineWithResultsPostProcessor(final String pipelineId) { - createSearchPipeline(pipelineId, NORMALIZATION_METHOD, COMBINATION_METHOD, Map.of()); + createSearchPipeline(pipelineId, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 5c9d4b2b7..397642007 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -28,12 +28,12 @@ import org.opensearch.action.search.SearchPhaseName; import org.opensearch.action.search.SearchProgressListener; import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.breaker.CircuitBreaker; -import org.opensearch.common.breaker.NoopCircuitBreaker; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.index.shard.ShardId; import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java index 273dbd522..54271d042 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; -import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -59,6 +58,10 @@ public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { private final static String RELATION_EQUAL_TO = "eq"; private final static String RELATION_GREATER_OR_EQUAL_TO = "gte"; + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + @Before public void setUp() throws Exception { super.setUp(); @@ -276,8 +279,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { // check case when number of weights and sub-queries are same createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) ); @@ -300,8 +303,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { deleteSearchPipeline(SEARCH_PIPELINE); createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) ); @@ -320,8 +323,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { deleteSearchPipeline(SEARCH_PIPELINE); createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) ); @@ -340,8 +343,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { deleteSearchPipeline(SEARCH_PIPELINE); createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) ); @@ -379,8 +382,8 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); createSearchPipeline( SEARCH_PIPELINE, - L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - COMBINATION_METHOD, + L2_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) ); @@ -399,29 +402,66 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); int totalExpectedDocQty = 5; - assertNotNull(searchResponseAsMap); - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(totalExpectedDocQty, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(Range.between(.6f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); + float[] minMaxExpectedScoresRange = { 0.6f, 1.0f }; + assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); + } - List> hitsNestedList = getNestedHits(searchResponseAsMap); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hitsNestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - // verify scores order - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range - assertTrue(Range.between(.6f, 1.0f).contains((float) scores.stream().map(Double::floatValue).max(Double::compare).get())); + @SneakyThrows + public void testMinMaxNormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); - // verify that all ids are unique - assertEquals(Set.copyOf(ids).size(), ids.size()); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + int totalExpectedDocQty = 5; + float[] minMaxExpectedScoresRange = { 0.6f, 1.0f }; + assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); + } + + @SneakyThrows + public void testL2NormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + int totalExpectedDocQty = 5; + float[] minMaxExpectedScoresRange = { 0.5f, 1.0f }; + assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); } private void initializeIndexIfNotExist(String indexName) throws IOException { @@ -603,4 +643,33 @@ private void assertWeightedScores( assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001); assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001); } + + private void assertHybridSearchResults(Map searchResponseAsMap, int totalExpectedDocQty, float[] minMaxScoreRange) { + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get())); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range + assertTrue( + Range.between(minMaxScoreRange[0], minMaxScoreRange[1]) + .contains(scores.stream().map(Double::floatValue).max(Double::compare).get()) + ); + + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index d9f63d291..3c3ca3776 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -10,41 +10,52 @@ import java.util.List; import java.util.Map; -import org.opensearch.test.OpenSearchTestCase; +public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { -public class ArithmeticMeanScoreCombinationTechniqueTests extends OpenSearchTestCase { - - private static final float DELTA_FOR_ASSERTION = 0.0001f; + public ArithmeticMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = this::arithmeticMean; + } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); - float[] scores = { 1.0f, 0.5f, 0.3f }; - float actualScore = combinationTechnique.combine(scores); - assertEquals(0.6f, actualScore, DELTA_FOR_ASSERTION); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); - float[] scores = { 1.0f, -1.0f, 0.6f }; - float actualScore = combinationTechnique.combine(scores); - assertEquals(0.8f, actualScore, DELTA_FOR_ASSERTION); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique( - Map.of(PARAM_NAME_WEIGHTS, List.of(0.9, 0.2, 0.7)) + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() ); - float[] scores = { 1.0f, 0.5f, 0.3f }; - float actualScore = combinationTechnique.combine(scores); - assertEquals(0.6722f, actualScore, DELTA_FOR_ASSERTION); + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique( - Map.of(PARAM_NAME_WEIGHTS, List.of(0.9, 0.15, 0.7)) + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() ); - float[] scores = { 1.0f, -1.0f, 0.6f }; - float actualScore = combinationTechnique.combine(scores); - assertEquals(0.825f, actualScore, DELTA_FOR_ASSERTION); + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + private float arithmeticMean(List scores, List weights) { + assertEquals(scores.size(), weights.size()); + float sumOfWeightedScores = 0; + float sumOfWeights = 0; + for (int i = 0; i < scores.size(); i++) { + float score = scores.get(i); + float weight = weights.get(i).floatValue(); + if (score >= 0) { + sumOfWeightedScores += score * weight; + sumOfWeights += weight; + } + } + return sumOfWeights == 0 ? 0f : sumOfWeightedScores / sumOfWeights; } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..cf9d1080f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.Arrays; +import java.util.List; +import java.util.function.BiFunction; + +import lombok.NoArgsConstructor; + +import org.apache.commons.lang.ArrayUtils; +import org.opensearch.test.OpenSearchTestCase; + +@NoArgsConstructor +public class BaseScoreCombinationTechniqueTests extends OpenSearchTestCase { + + protected BiFunction, List, Float> expectedScoreFunction; + + private static final float DELTA_FOR_ASSERTION = 0.0001f; + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(final ScoreCombinationTechnique technique) { + float[] scores = { 1.0f, 0.5f, 0.3f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), List.of(1.0, 1.0, 1.0)); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(final ScoreCombinationTechnique technique) { + float[] scores = { 1.0f, -1.0f, 0.6f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), List.of(1.0, 1.0, 1.0)); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + List weights + ) { + float[] scores = { 1.0f, 0.5f, 0.3f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + List weights + ) { + float[] scores = { 1.0f, -1.0f, 0.6f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..02a8084ef --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; + +import java.util.List; +import java.util.Map; + +public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + public HarmonicMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = (scores, weights) -> harmonicMean(scores, weights); + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() + ); + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() + ); + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + private float harmonicMean(List scores, List weights) { + assertEquals(scores.size(), weights.size()); + float w = 0, h = 0; + for (int i = 0; i < scores.size(); i++) { + float score = scores.get(i), weight = weights.get(i).floatValue(); + if (score > 0) { + w += weight; + h += weight / score; + } + } + return h == 0 ? 0f : w / h; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java new file mode 100644 index 000000000..9f164b3d3 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.hamcrest.Matchers.containsString; + +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class ScoreCombinationFactoryTests extends OpenSearchQueryTestCase { + + public void testArithmeticWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("arithmetic_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof ArithmeticMeanScoreCombinationTechnique); + } + + public void testHarmonicWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("harmonic_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof HarmonicMeanScoreCombinationTechnique); + } + + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + IllegalArgumentException illegalArgumentException = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationFactory.createCombination("randomname") + ); + org.hamcrest.MatcherAssert.assertThat( + illegalArgumentException.getMessage(), + containsString("provided combination technique is not supported") + ); + } +}