From 84a627cdf39f066333defa373b223311d152594a Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 11 Oct 2024 11:06:11 -0700 Subject: [PATCH] Move rank constant param to factory Signed-off-by: Martin Gaievski --- .../processor/NormalizationExecuteDTO.java | 3 - .../NormalizationProcessorWorkflow.java | 2 - .../processor/NormalizeScoresDTO.java | 3 - .../neuralsearch/processor/RRFProcessor.java | 1 - .../RRFScoreCombinationTechnique.java | 3 +- .../combination/ScoreCombinationFactory.java | 1 - .../factory/RRFProcessorFactory.java | 4 +- .../RRFNormalizationTechnique.java | 44 ++++-- .../ScoreNormalizationFactory.java | 18 ++- .../normalization/ScoreNormalizationUtil.java | 139 ++++++++++++++++++ ....java => ScoreNormalizationUtilTests.java} | 2 +- .../RRFNormalizationTechniqueTests.java | 11 +- 12 files changed, 197 insertions(+), 34 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java rename src/test/java/org/opensearch/neuralsearch/processor/combination/{ScoreCombinationUtilTests.java => ScoreNormalizationUtilTests.java} (97%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java index 8f38df7bc..0cea24eec 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java @@ -8,7 +8,6 @@ import lombok.Builder; import lombok.Getter; import lombok.NonNull; -import org.opensearch.common.Nullable; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; @@ -34,6 +33,4 @@ public class NormalizationExecuteDTO { private ScoreNormalizationTechnique normalizationTechnique; @NonNull private 
ScoreCombinationTechnique combinationTechnique; - @Nullable - private int rankConstant; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index c64508be3..6507e3bd9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -56,7 +56,6 @@ public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) { final Optional fetchSearchResultOptional = normalizationExecuteDTO.getFetchSearchResultOptional(); final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDTO.getNormalizationTechnique(); final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDTO.getCombinationTechnique(); - final int rankConstant = normalizationExecuteDTO.getRankConstant(); // save original state List unprocessedDocIds = unprocessedDocIds(querySearchResults); @@ -68,7 +67,6 @@ public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) { NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() .queryTopDocs(queryTopDocs) .normalizationTechnique(normalizationTechnique) - .rankConstant(rankConstant) .build(); // normalize diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java index af0441c18..c932a157d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java @@ -8,7 +8,6 @@ import lombok.Builder; import lombok.Getter; import lombok.NonNull; -import org.opensearch.common.Nullable; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import java.util.List; @@ -24,6 +23,4 @@ public class 
NormalizeScoresDTO { private List queryTopDocs; @NonNull private ScoreNormalizationTechnique normalizationTechnique; - @Nullable - private int rankConstant; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java index 473a08ad3..b3a450f11 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -72,7 +72,6 @@ public void process( .fetchSearchResultOptional(fetchSearchResult) .normalizationTechnique(normalizationTechnique) .combinationTechnique(combinationTechnique) - .rankConstant(rankConstant) .build(); normalizationWorkflow.execute(normalizationExecuteDTO); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java index f3f41c22a..0ec768ba0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -17,10 +17,11 @@ public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "rrf"; + private final ScoreCombinationUtil scoreCombinationUtil; // Not currently using weights for RRF, no need to modify or verify these params public RRFScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { - ; + this.scoreCombinationUtil = combinationUtil; } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index edc5fa071..1e560342a 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -18,7 +18,6 @@ public class ScoreCombinationFactory { Map.of(), scoreCombinationUtil ); - public static final ScoreCombinationTechnique RRF_METHOD = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); private final Map, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of( ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java index 2a34c9829..f7d075c16 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java @@ -68,7 +68,9 @@ public SearchPhaseResultsProcessor create( Map combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE); - ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.RRF_METHOD; + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination( + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); if (Objects.nonNull(combinationClause)) { String combinationTechnique = readStringProperty( RRFProcessor.TYPE, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index c9df56e10..78ada034c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -5,9 +5,12 @@ package org.opensearch.neuralsearch.processor.normalization; import 
java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Locale; +import java.util.Set; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; @@ -22,14 +25,15 @@ public class RRFNormalizationTechnique implements ScoreNormalizationTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "rrf"; + public static final int DEFAULT_RANK_CONSTANT = 60; + public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT); - private void validateRankConstant(final int rankConstant) { - boolean isOutOfRange = rankConstant < 1 || rankConstant >= Integer.MAX_VALUE; - if (isOutOfRange) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "rank constant must be >= 1 and < (2^31)-1, submitted rank constant: %d", rankConstant) - ); - } + final int rankConstant; + + public RRFNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS); + rankConstant = getRankConstant(params); } /** @@ -46,8 +50,6 @@ private void validateRankConstant(final int rankConstant) { @Override public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); - final int rankConstant = normalizeScoresDTO.getRankConstant(); - validateRankConstant(rankConstant); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { continue; @@ -56,11 +58,33 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { int numSubQueriesBound = topDocsPerSubQuery.size(); for (int index = 0; index < numSubQueriesBound; index++) { int numDocsPerSubQueryBound = topDocsPerSubQuery.get(index).scoreDocs.length; + ScoreDoc[] scoreDocs = topDocsPerSubQuery.get(index).scoreDocs; for 
(int j = 0; j < numDocsPerSubQueryBound; j++) { - topDocsPerSubQuery.get(index).scoreDocs[j].score = (1.f / (float) (rankConstant + j + 1)); + scoreDocs[j].score = (1.f / (float) (rankConstant + j + 1)); } } } } + private int getRankConstant(final Map params) { + if (params.containsKey(PARAM_NAME_RANK_CONSTANT)) { + int rankConstant = (int) params.get(PARAM_NAME_RANK_CONSTANT); + validateRankConstant(rankConstant); + return rankConstant; + } + return DEFAULT_RANK_CONSTANT; + } + + private void validateRankConstant(final int rankConstant) { + boolean isOutOfRange = rankConstant < 1 || rankConstant >= 10000; + if (isOutOfRange) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "rank constant must be in the interval between 1 and 10.000, submitted rank constant: %d", + rankConstant + ) + ); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index 9e5de68f4..7c62893a5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -6,21 +6,24 @@ import java.util.Map; import java.util.Optional; +import java.util.function.Function; /** * Abstracts creation of exact score normalization method based on technique name */ public class ScoreNormalizationFactory { + private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); - private final Map scoreNormalizationMethodsMap = Map.of( + private final Map, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of( MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, - new MinMaxScoreNormalizationTechnique(), + params -> new 
MinMaxScoreNormalizationTechnique(), L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - new L2ScoreNormalizationTechnique(), + params -> new L2ScoreNormalizationTechnique(), RRFNormalizationTechnique.TECHNIQUE_NAME, - new RRFNormalizationTechnique() + params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil) ); /** @@ -29,7 +32,12 @@ public class ScoreNormalizationFactory { * @return instance of ScoreNormalizationMethod for technique name */ public ScoreNormalizationTechnique createNormalization(final String technique) { + return createNormalization(technique, Map.of()); + } + + public ScoreNormalizationTechnique createNormalization(final String technique, final Map params) { return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique)) - .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")); + .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")) + .apply(params); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java new file mode 100644 index 000000000..556c8a6c6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import com.google.common.math.DoubleMath; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.Range; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Collection of utility methods for score normalization technique classes + */ +@Log4j2 +class ScoreNormalizationUtil { +
private static final String PARAM_NAME_WEIGHTS = "weights"; + private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; + + /** + * Get collection of weights based on user provided config + * @param params map of named parameters and their values + * @return collection of weights + */ + public List getRankConstant(final Map params) { + if (Objects.isNull(params) || params.isEmpty()) { + return List.of(); + } + // get weights, we don't need to check for instance as it's done during validation + List weightsList = ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() + .map(Double::floatValue) + .collect(Collectors.toUnmodifiableList()); + validateWeights(weightsList); + return weightsList; + } + + /** + * Validate config parameters for this technique + * @param actualParams map of parameters in form of name-value + * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique + */ + public void validateParams(final Map actualParams, final Set supportedParams) { + if (Objects.isNull(actualParams) || actualParams.isEmpty()) { + return; + } + // check if only supported params are passed + Optional optionalNotSupportedParam = actualParams.keySet() + .stream() + .filter(paramName -> !supportedParams.contains(paramName)) + .findFirst(); + if (optionalNotSupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "provided parameter for combination technique is not supported. 
supported parameters are [%s]", + supportedParams.stream().collect(Collectors.joining(",")) + ) + ); + } + + // check param types + if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) + ); + } + } + } + + /** + * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise + * @param weights collection of weights for sub-queries + * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query + * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default + */ + public float getWeightForSubQuery(final List weights, final int indexOfSubQuery) { + return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; + } + + /** + * Check if number of weights matches number of queries. This does not apply for case when + * weights were not provided, as this is valid default value + * @param scores collection of scores from all sub-queries of a single hybrid search query + * @param weights score combination weights that are defined as part of search result processor + */ + protected void validateIfWeightsMatchScores(final float[] scores, final List weights) { + if (weights.isEmpty()) { + return; + } + if (scores.length != weights.size()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "number of weights [%d] must match number of sub-queries [%d] in hybrid query", + weights.size(), + scores.length + ) + ); + } + } + + /** + * Check if provided weights are valid for combination. 
Following conditions are checked: + * - every weight is between 0.0 and 1.0 + * - sum of all weights must be equal to 1.0 + * @param weightsList collection of weights to validate + */ + private void validateWeights(final List weightsList) { + boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight)); + if (isOutOfRange) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "all weights must be in range [0.0 ... 1.0], submitted weights: %s", + Arrays.toString(weightsList.toArray(new Float[0])) + ) + ); + } + float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum); + if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "sum of weights for combination must be equal to 1.0, submitted weights: %s", + Arrays.toString(weightsList.toArray(new Float[0])) + ) + ); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java similarity index 97% rename from src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java index 9e00e3833..009681116 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase { +public class ScoreNormalizationUtilTests extends OpenSearchQueryTestCase { public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() { ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff
--git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java index 1e5122eb5..bd3f7cbf0 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -12,6 +12,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import java.util.List; +import java.util.Map; /** * Abstracts testing of normalization of scores based on RRF method @@ -19,9 +20,10 @@ public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.001f; static final int RANK_CONSTANT = 60; + private ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { - RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(); + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); Float[] scores = { 0.5f, 0.2f }; List compoundTopDocs = List.of( new CompoundTopDocs( @@ -38,7 +40,6 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() .queryTopDocs(compoundTopDocs) .normalizationTechnique(normalizationTechnique) - .rankConstant(RANK_CONSTANT) .build(); normalizationTechnique.normalize(normalizeScoresDTO); @@ -62,7 +63,7 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() } public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { - RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(); + RRFNormalizationTechnique normalizationTechnique = new 
RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); Float[] scoresQuery1 = { 0.5f, 0.2f }; Float[] scoresQuery2 = { 0.9f, 0.7f, 0.1f }; List compoundTopDocs = List.of( @@ -88,7 +89,6 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() .queryTopDocs(compoundTopDocs) .normalizationTechnique(normalizationTechnique) - .rankConstant(60) .build(); normalizationTechnique.normalize(normalizeScoresDTO); @@ -116,7 +116,7 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce } public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() { - RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(); + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); Float[] scoresShard1Query1 = { 0.5f, 0.2f }; Float[] scoresShard1and2Query3 = { 0.9f, 0.7f, 0.1f, 0.8f, 0.7f, 0.6f, 0.5f }; Float[] scoresShard2Query2 = { 2.9f, 0.7f }; @@ -162,7 +162,6 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() .queryTopDocs(compoundTopDocs) .normalizationTechnique(normalizationTechnique) - .rankConstant(60) .build(); normalizationTechnique.normalize(normalizeScoresDTO);