diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java index cb5d3372f..7481267a2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java @@ -31,11 +31,9 @@ @AllArgsConstructor @Log4j2 public class RRFProcessorFactory implements Processor.Factory { - public static final String NORMALIZATION_CLAUSE = "combination"; public static final String COMBINATION_CLAUSE = "combination"; public static final String TECHNIQUE = "technique"; public static final String PARAMETERS = "parameters"; - public static final int DEFAULT_RANK_CONSTANT = 60; private final NormalizationProcessorWorkflow normalizationProcessorWorkflow; private ScoreNormalizationFactory scoreNormalizationFactory; @@ -50,27 +48,14 @@ public SearchPhaseResultsProcessor create( final Map config, final Processor.PipelineContext pipelineContext ) throws Exception { - Map normalizationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE); - // reads parameters passed in from user to get rank constant to be used in RRFNormalizationTechnique - Map normalizationParams = readOptionalMap(RRFProcessor.TYPE, tag, normalizationClause, PARAMETERS); - int rankConstant = (int) normalizationParams.getOrDefault("rank_constant", DEFAULT_RANK_CONSTANT); - ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD; - if (Objects.nonNull(normalizationClause)) { - String normalizationTechniqueName = readStringProperty( - RRFProcessor.TYPE, - tag, - normalizationClause, - TECHNIQUE, - RRFNormalizationTechnique.TECHNIQUE_NAME - ); - normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName); - } - - Map combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE); - + // assign defaults + ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization( + RRFNormalizationTechnique.TECHNIQUE_NAME + ); ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination( RRFScoreCombinationTechnique.TECHNIQUE_NAME ); + Map combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE); if (Objects.nonNull(combinationClause)) { String combinationTechnique = readStringProperty( RRFProcessor.TYPE, @@ -80,8 +65,9 @@ public SearchPhaseResultsProcessor create( RRFScoreCombinationTechnique.TECHNIQUE_NAME ); // check for optional combination params - Map combinationParams = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS); - scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams); + Map params = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS); + normalizationTechnique = scoreNormalizationFactory.createNormalization(RRFNormalizationTechnique.TECHNIQUE_NAME, params); + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique); } log.info( "Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]", diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index e677bd3a7..d9f55019c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -10,6 +10,7 @@ import java.util.Locale; import java.util.Set; +import org.apache.commons.lang3.math.NumberUtils; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; @@ -29,7 +30,7 @@ public class RRFNormalizationTechnique implements ScoreNormalizationTechnique { public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT); - final int rankConstant; + private final int rankConstant; public RRFNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS); @@ -55,9 +56,9 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { continue; } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); - for (int index = 0; index < topDocsPerSubQuery.size(); index++) { - int docsCountPerSubQuery = topDocsPerSubQuery.get(index).scoreDocs.length; - ScoreDoc[] scoreDocs = topDocsPerSubQuery.get(index).scoreDocs; + for (TopDocs topDocs : topDocsPerSubQuery) { + int docsCountPerSubQuery = topDocs.scoreDocs.length; + ScoreDoc[] scoreDocs = topDocs.scoreDocs; for (int j = 0; j < docsCountPerSubQuery; j++) { scoreDocs[j].score = (1.f / (float) (rankConstant + j + 1)); } @@ -66,12 +67,12 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { } private int getRankConstant(final Map params) { - if (params.containsKey(PARAM_NAME_RANK_CONSTANT)) { - int rankConstant = (int) params.get(PARAM_NAME_RANK_CONSTANT); - validateRankConstant(rankConstant); - return rankConstant; + if (!params.containsKey(PARAM_NAME_RANK_CONSTANT)) { + return DEFAULT_RANK_CONSTANT; } - return DEFAULT_RANK_CONSTANT; + int rankConstant = getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT); + validateRankConstant(rankConstant); + return rankConstant; } private void validateRankConstant(final int rankConstant) { @@ -86,4 +87,14 @@ private void validateRankConstant(final int rankConstant) { ); } } + + public static int getParamAsInteger(final Map parameters, final String fieldName) { + try { + return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName))); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName) + ); + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java new file mode 100644 index 000000000..03f56fc82 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.PARAMETERS; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.COMBINATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.TECHNIQUE; + +public class RRFProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testCombinationParams_whenValidValues_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof RRFProcessor); + RRFProcessor rrfProcessor = (RRFProcessor) searchPhaseResultsProcessor; + assertEquals("score-ranker-processor", rrfProcessor.getType()); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNegative_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", -1))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: -1") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsTooLarge_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 50_000))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: 50000") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNotNumeric_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", "string")))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("parameter [rank_constant] must be an integer")); + } + + @SneakyThrows + public void testInvalidCombinationName_whenUnsupportedFunction_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "my_function", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100)))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("provided combination technique is not supported")); + } +}