Skip to content

Commit

Permalink
Adding tests for a factory class and params
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 16, 2024
1 parent a07b379 commit 2734488
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@
@AllArgsConstructor
@Log4j2
public class RRFProcessorFactory implements Processor.Factory<SearchPhaseResultsProcessor> {
public static final String NORMALIZATION_CLAUSE = "combination";
public static final String COMBINATION_CLAUSE = "combination";
public static final String TECHNIQUE = "technique";
public static final String PARAMETERS = "parameters";
public static final int DEFAULT_RANK_CONSTANT = 60;

private final NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private ScoreNormalizationFactory scoreNormalizationFactory;
Expand All @@ -50,27 +48,14 @@ public SearchPhaseResultsProcessor create(
final Map<String, Object> config,
final Processor.PipelineContext pipelineContext
) throws Exception {
Map<String, Object> normalizationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE);
// reads parameters passed in from user to get rank constant to be used in RRFNormalizationTechnique
Map<String, Object> normalizationParams = readOptionalMap(RRFProcessor.TYPE, tag, normalizationClause, PARAMETERS);
int rankConstant = (int) normalizationParams.getOrDefault("rank_constant", DEFAULT_RANK_CONSTANT);
ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD;
if (Objects.nonNull(normalizationClause)) {
String normalizationTechniqueName = readStringProperty(
RRFProcessor.TYPE,
tag,
normalizationClause,
TECHNIQUE,
RRFNormalizationTechnique.TECHNIQUE_NAME
);
normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName);
}

Map<String, Object> combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

// assign defaults
ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization(
RRFNormalizationTechnique.TECHNIQUE_NAME
);
ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination(
RRFScoreCombinationTechnique.TECHNIQUE_NAME
);
Map<String, Object> combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE);
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = readStringProperty(
RRFProcessor.TYPE,
Expand All @@ -80,8 +65,9 @@ public SearchPhaseResultsProcessor create(
RRFScoreCombinationTechnique.TECHNIQUE_NAME
);
// check for optional combination params
Map<String, Object> combinationParams = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS);
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
Map<String, Object> params = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS);
normalizationTechnique = scoreNormalizationFactory.createNormalization(RRFNormalizationTechnique.TECHNIQUE_NAME, params);
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique);
}
log.info(
"Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Locale;
import java.util.Set;

import org.apache.commons.lang3.math.NumberUtils;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
Expand All @@ -29,7 +30,7 @@ public class RRFNormalizationTechnique implements ScoreNormalizationTechnique {
public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT);

final int rankConstant;
private final int rankConstant;

public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNormalizationUtil scoreNormalizationUtil) {
scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS);
Expand All @@ -55,9 +56,9 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int index = 0; index < topDocsPerSubQuery.size(); index++) {
int docsCountPerSubQuery = topDocsPerSubQuery.get(index).scoreDocs.length;
ScoreDoc[] scoreDocs = topDocsPerSubQuery.get(index).scoreDocs;
for (TopDocs topDocs : topDocsPerSubQuery) {
int docsCountPerSubQuery = topDocs.scoreDocs.length;
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
for (int j = 0; j < docsCountPerSubQuery; j++) {
scoreDocs[j].score = (1.f / (float) (rankConstant + j + 1));
}
Expand All @@ -66,12 +67,12 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
}

private int getRankConstant(final Map<String, Object> params) {
if (params.containsKey(PARAM_NAME_RANK_CONSTANT)) {
int rankConstant = (int) params.get(PARAM_NAME_RANK_CONSTANT);
validateRankConstant(rankConstant);
return rankConstant;
if (!params.containsKey(PARAM_NAME_RANK_CONSTANT)) {
return DEFAULT_RANK_CONSTANT;
}
return DEFAULT_RANK_CONSTANT;
int rankConstant = getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT);
validateRankConstant(rankConstant);
return rankConstant;
}

private void validateRankConstant(final int rankConstant) {
Expand All @@ -86,4 +87,14 @@ private void validateRankConstant(final int rankConstant) {
);
}
}

/**
 * Reads the value stored under {@code fieldName} and converts it to an {@code int}.
 * The value is converted through its string form, so it does not have to be an
 * {@link Integer} instance already.
 *
 * @param parameters map of user-supplied parameters
 * @param fieldName  key of the parameter to read
 * @return the parameter value as an {@code int}
 * @throws IllegalArgumentException if the string form of the value cannot be parsed as an integer;
 *                                  the underlying {@link NumberFormatException} is attached as the cause
 */
public static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
    try {
        return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName)));
    } catch (NumberFormatException e) {
        // keep the original parsing failure as the cause so the offending input is not lost
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName),
            e
        );
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.factory;

import lombok.SneakyThrows;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.RRFProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.test.OpenSearchTestCase;

import java.util.HashMap;
import java.util.Map;

import static org.mockito.Mockito.mock;

import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.PARAMETERS;
import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.COMBINATION_CLAUSE;
import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.TECHNIQUE;

/**
 * Unit tests for {@link RRFProcessorFactory#create}. Each test builds a processor config with a
 * combination clause (technique name plus a {@code rank_constant} parameter) and checks either
 * that a {@link RRFProcessor} is produced or that creation is rejected with the expected message.
 */
public class RRFProcessorFactoryTests extends OpenSearchTestCase {

    @SneakyThrows
    public void testCombinationParams_whenValidValues_thenSuccessful() {
        RRFProcessorFactory factory = newFactory();
        Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>();
        Map<String, Object> config = combinationConfig("rrf", 100);
        Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);

        SearchPhaseResultsProcessor created = factory.create(
            processorFactories,
            "tag",
            "description",
            false,
            config,
            pipelineContext
        );

        assertNotNull(created);
        assertTrue(created instanceof RRFProcessor);
        RRFProcessor rrfProcessor = (RRFProcessor) created;
        assertEquals("score-ranker-processor", rrfProcessor.getType());
    }

    @SneakyThrows
    public void testInvalidCombinationParams_whenRankIsNegative_thenFail() {
        assertCreationFails(
            "rrf",
            -1,
            "rank constant must be in the interval between 1 and 10000, submitted rank constant: -1"
        );
    }

    @SneakyThrows
    public void testInvalidCombinationParams_whenRankIsTooLarge_thenFail() {
        assertCreationFails(
            "rrf",
            50_000,
            "rank constant must be in the interval between 1 and 10000, submitted rank constant: 50000"
        );
    }

    @SneakyThrows
    public void testInvalidCombinationParams_whenRankIsNotNumeric_thenFail() {
        assertCreationFails("rrf", "string", "parameter [rank_constant] must be an integer");
    }

    @SneakyThrows
    public void testInvalidCombinationName_whenUnsupportedFunction_thenFail() {
        assertCreationFails("my_function", 100, "provided combination technique is not supported");
    }

    /** Builds a factory wired with real workflow, normalization and combination collaborators. */
    private static RRFProcessorFactory newFactory() {
        return new RRFProcessorFactory(
            new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()),
            new ScoreNormalizationFactory(),
            new ScoreCombinationFactory()
        );
    }

    /**
     * Builds a mutable processor config holding a single combination clause with the given
     * technique name and a {@code rank_constant} parameter.
     */
    private static Map<String, Object> combinationConfig(final String technique, final Object rankConstant) {
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("rank_constant", rankConstant);

        Map<String, Object> combination = new HashMap<>();
        combination.put(TECHNIQUE, technique);
        combination.put(PARAMETERS, parameters);

        Map<String, Object> config = new HashMap<>();
        config.put(COMBINATION_CLAUSE, combination);
        return config;
    }

    /**
     * Asserts that creating a processor from the given combination settings throws an
     * {@link IllegalArgumentException} whose message contains {@code expectedMessagePart}.
     */
    private void assertCreationFails(final String technique, final Object rankConstant, final String expectedMessagePart) {
        // fixtures are built outside the lambda so only the create call itself is expected to throw
        RRFProcessorFactory factory = newFactory();
        Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>();
        Map<String, Object> config = combinationConfig(technique, rankConstant);
        Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);

        IllegalArgumentException exception = expectThrows(
            IllegalArgumentException.class,
            () -> factory.create(processorFactories, "tag", "description", false, config, pipelineContext)
        );
        assertTrue(exception.getMessage().contains(expectedMessagePart));
    }
}

0 comments on commit 2734488

Please sign in to comment.