Move rank constant param to factory
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Oct 11, 2024
1 parent 0647ad6 commit 84a627c
Showing 12 changed files with 197 additions and 34 deletions.
@@ -8,7 +8,6 @@
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
@@ -34,6 +33,4 @@ public class NormalizationExecuteDTO {
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
@Nullable
private int rankConstant;
}
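
For orientation, a hypothetical same-package sketch (my own example, not part of the diff): after this commit the rank constant is no longer carried on NormalizationExecuteDTO; it is read from the params map of the RRF normalization technique introduced further below.

import java.util.Map;

// Hypothetical illustration of where the rank constant now lives; the class
// name and method are mine, only the constructor signature comes from the diff.
class RankConstantWiringSketch {
    RRFNormalizationTechnique rrfWithDefaultRankConstant() {
        // "rank_constant" is the supported parameter name; 60 is the default value
        return new RRFNormalizationTechnique(Map.of("rank_constant", 60), new ScoreNormalizationUtil());
    }
}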
@@ -56,7 +56,6 @@ public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) {
final Optional<FetchSearchResult> fetchSearchResultOptional = normalizationExecuteDTO.getFetchSearchResultOptional();
final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDTO.getNormalizationTechnique();
final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDTO.getCombinationTechnique();
final int rankConstant = normalizationExecuteDTO.getRankConstant();
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

@@ -68,7 +67,6 @@ public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) {
NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
.queryTopDocs(queryTopDocs)
.normalizationTechnique(normalizationTechnique)
.rankConstant(rankConstant)
.build();

// normalize
@@ -8,7 +8,6 @@
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

import java.util.List;
@@ -24,6 +23,4 @@ public class NormalizeScoresDTO {
private List<CompoundTopDocs> queryTopDocs;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@Nullable
private int rankConstant;
}
@@ -72,7 +72,6 @@ public <Result extends SearchPhaseResult> void process(
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.rankConstant(rankConstant)
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);
}
@@ -17,10 +17,11 @@
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";
private final ScoreCombinationUtil scoreCombinationUtil;

// Not currently using weights for RRF, no need to modify or verify these params
public RRFScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
;
this.scoreCombinationUtil = combinationUtil;
}

@Override
@@ -18,7 +18,6 @@ public class ScoreCombinationFactory {
Map.of(),
scoreCombinationUtil
);
public static final ScoreCombinationTechnique RRF_METHOD = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil);

private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
@@ -68,7 +68,9 @@ public SearchPhaseResultsProcessor create(

Map<String, Object> combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.RRF_METHOD;
ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination(
RRFScoreCombinationTechnique.TECHNIQUE_NAME
);
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = readStringProperty(
RRFProcessor.TYPE,
@@ -5,9 +5,12 @@
package org.opensearch.neuralsearch.processor.normalization;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Locale;
import java.util.Set;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

@@ -22,14 +25,15 @@
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";
public static final int DEFAULT_RANK_CONSTANT = 60;
public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT);

private void validateRankConstant(final int rankConstant) {
boolean isOutOfRange = rankConstant < 1 || rankConstant >= Integer.MAX_VALUE;
if (isOutOfRange) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "rank constant must be >= 1 and < (2^31)-1, submitted rank constant: %d", rankConstant)
);
}
final int rankConstant;

public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNormalizationUtil scoreNormalizationUtil) {
scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS);
rankConstant = getRankConstant(params);
}

/**
@@ -46,8 +50,6 @@ private void validateRankConstant(final int rankConstant) {
@Override
public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
final int rankConstant = normalizeScoresDTO.getRankConstant();
validateRankConstant(rankConstant);
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
@@ -56,11 +58,33 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
int numSubQueriesBound = topDocsPerSubQuery.size();
for (int index = 0; index < numSubQueriesBound; index++) {
int numDocsPerSubQueryBound = topDocsPerSubQuery.get(index).scoreDocs.length;
ScoreDoc[] scoreDocs = topDocsPerSubQuery.get(index).scoreDocs;
for (int j = 0; j < numDocsPerSubQueryBound; j++) {
topDocsPerSubQuery.get(index).scoreDocs[j].score = (1.f / (float) (rankConstant + j + 1));
scoreDocs[j].score = (1.f / (float) (rankConstant + j + 1));
}
}
}
}

private int getRankConstant(final Map<String, Object> params) {
if (params.containsKey(PARAM_NAME_RANK_CONSTANT)) {
int rankConstant = (int) params.get(PARAM_NAME_RANK_CONSTANT);
validateRankConstant(rankConstant);
return rankConstant;
}
return DEFAULT_RANK_CONSTANT;
}

private void validateRankConstant(final int rankConstant) {
boolean isOutOfRange = rankConstant < 1 || rankConstant >= 10000;
if (isOutOfRange) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"rank constant must be in the interval between 1 and 10.000, submitted rank constant: %d",
rankConstant
)
);
}
}
}
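
For context, the reciprocal rank fusion scoring applied in normalize() above assigns each document the score 1 / (rank_constant + rank). A minimal standalone sketch of that computation, using my own helper names; only the score expression and the default rank constant of 60 come from the commit.

import java.util.Arrays;

// Hypothetical standalone illustration: how RRF turns a 0-based rank position
// into a normalized score using the rank constant.
public class RrfScoreSketch {
    static float[] rrfScores(int rankConstant, int numDocs) {
        float[] scores = new float[numDocs];
        for (int j = 0; j < numDocs; j++) {
            // same expression as in RRFNormalizationTechnique.normalize()
            scores[j] = 1.f / (float) (rankConstant + j + 1);
        }
        return scores;
    }

    public static void main(String[] args) {
        // with the default rank constant of 60, the top three ranks score
        // 1/61, 1/62 and 1/63 respectively
        System.out.println(Arrays.toString(rrfScores(60, 3)));
    }
}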
@@ -6,21 +6,24 @@

import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

/**
* Abstracts creation of exact score normalization method based on technique name
*/
public class ScoreNormalizationFactory {

private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil();

public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique();

private final Map<String, ScoreNormalizationTechnique> scoreNormalizationMethodsMap = Map.of(
private final Map<String, Function<Map<String, Object>, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of(
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
new MinMaxScoreNormalizationTechnique(),
params -> new MinMaxScoreNormalizationTechnique(),
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
new L2ScoreNormalizationTechnique(),
params -> new L2ScoreNormalizationTechnique(),
RRFNormalizationTechnique.TECHNIQUE_NAME,
new RRFNormalizationTechnique()
params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil)
);

/**
@@ -29,7 +32,12 @@ public class ScoreNormalizationFactory {
* @return instance of ScoreNormalizationMethod for technique name
*/
public ScoreNormalizationTechnique createNormalization(final String technique) {
return createNormalization(technique, Map.of());
}

public ScoreNormalizationTechnique createNormalization(final String technique, final Map<String, Object> params) {
return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique))
.orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"));
.orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"))
.apply(params);
}
}
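
As a usage sketch (my own example, not from the commit), the factory can now be asked for an RRF technique with a custom rank constant, while the parameterless overload keeps working; the params map is passed straight to the technique constructor, which is where validation now happens.

import java.util.Map;

// Hypothetical caller, assumed to sit in the same package as the factory so
// the technique classes resolve without extra imports.
class ScoreNormalizationFactoryUsageSketch {
    void demo() {
        ScoreNormalizationFactory factory = new ScoreNormalizationFactory();

        // parameterless lookup still works and delegates to an empty params map
        ScoreNormalizationTechnique minMax =
            factory.createNormalization(MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME);

        // RRF with an explicit rank_constant; the value is validated in the
        // RRFNormalizationTechnique constructor
        ScoreNormalizationTechnique rrf = factory.createNormalization(
            RRFNormalizationTechnique.TECHNIQUE_NAME,
            Map.of(RRFNormalizationTechnique.PARAM_NAME_RANK_CONSTANT, 100)
        );
    }
}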
@@ -0,0 +1,139 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.math.DoubleMath;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.Range;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Collection of utility methods for score combination technique classes
*/
@Log4j2
class ScoreNormalizationUtil {
private static final String PARAM_NAME_WEIGHTS = "weights";
private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;

/**
* Get collection of weights based on user provided config
* @param params map of named parameters and their values
* @return collection of weights
*/
public List<Float> getRankConstant(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
List<Float> weightsList = ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
validateWeights(weightsList);
return weightsList;
}

/**
* Validate config parameters for this technique
* @param actualParams map of parameters in form of name-value
* @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique
*/
public void validateParams(final Map<String, Object> actualParams, final Set<String> supportedParams) {
if (Objects.isNull(actualParams) || actualParams.isEmpty()) {
return;
}
// check if only supported params are passed
Optional<String> optionalNotSupportedParam = actualParams.keySet()
.stream()
.filter(paramName -> !supportedParams.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
supportedParams.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
}
}
}

/**
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
* @param weights collection of weights for sub-queries
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
public float getWeightForSubQuery(final List<Float> weights, final int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}

/**
* Check if number of weights matches number of queries. This does not apply for case when
* weights were not provided, as this is valid default value
* @param scores collection of scores from all sub-queries of a single hybrid search query
* @param weights score combination weights that are defined as part of search result processor
*/
protected void validateIfWeightsMatchScores(final float[] scores, final List<Float> weights) {
if (weights.isEmpty()) {
return;
}
if (scores.length != weights.size()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"number of weights [%d] must match number of sub-queries [%d] in hybrid query",
weights.size(),
scores.length
)
);
}
}

/**
* Check if provided weights are valid for combination. Following conditions are checked:
* - every weight is between 0.0 and 1.0
* - sum of all weights must be equal 1.0
* @param weightsList
*/
private void validateWeights(final List<Float> weightsList) {
boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight));
if (isOutOfRange) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"all weights must be in range [0.0 ... 1.0], submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum);
if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"sum of weights for combination must be equal to 1.0, submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
}
}
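
A small illustrative example of how validateParams behaves (my own sketch; the utility class is package-private, so this only compiles from the same package).

import java.util.Map;
import java.util.Set;

// Hypothetical same-package sketch of the validation added above.
class ScoreNormalizationUtilUsageSketch {
    void demo() {
        ScoreNormalizationUtil util = new ScoreNormalizationUtil();

        // passes: the only key is in the supported set
        util.validateParams(Map.of("rank_constant", 60), Set.of("rank_constant"));

        // throws IllegalArgumentException: "unknown_param" is not a supported parameter
        util.validateParams(Map.of("unknown_param", 1), Set.of("rank_constant"));
    }
}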
@@ -12,7 +12,7 @@

import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase {
public class ScoreNormalizationUtilTests extends OpenSearchQueryTestCase {

public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() {
ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();