From b867e69efb485cf898c2bb489c86711d4a584a2c Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 3 Aug 2023 17:54:11 +0200 Subject: [PATCH] Add geometric mean normalization for scores (#239) * Add geometric mean normalization for scores Signed-off-by: Martin Gaievski --- ...eometricMeanScoreCombinationTechnique.java | 54 +++ .../combination/ScoreCombinationFactory.java | 4 +- .../opensearch/neuralsearch/TestUtils.java | 98 ++++- .../common/BaseNeuralSearchIT.java | 55 +++ ...nIT.java => NormalizationProcessorIT.java} | 304 +------------ .../processor/ScoreCombinationIT.java | 412 ++++++++++++++++++ .../processor/ScoreNormalizationIT.java | 295 +++++++++++++ .../processor/TextEmbeddingProcessorIT.java | 14 + ...ticMeanScoreCombinationTechniqueTests.java | 37 +- .../BaseScoreCombinationTechniqueTests.java | 47 +- ...ricMeanScoreCombinationTechniqueTests.java | 98 +++++ ...nicMeanScoreCombinationTechniqueTests.java | 33 +- .../ScoreCombinationFactoryTests.java | 8 + .../neuralsearch/query/HybridQueryIT.java | 70 +-- .../neuralsearch/query/NeuralQueryIT.java | 44 +- 15 files changed, 1187 insertions(+), 386 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java rename src/test/java/org/opensearch/neuralsearch/processor/{ScoreNormalizationCombinationIT.java => NormalizationProcessorIT.java} (57%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java new file mode 100644 index 000000000..c17d641cc --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Abstracts combination of scores based on geometrical mean method + */ +public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique { + + public static final String TECHNIQUE_NAME = "geometric_mean"; + public static final String PARAM_NAME_WEIGHTS = "weights"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); + private static final Float ZERO_SCORE = 0.0f; + private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; + + public GeometricMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); + } + + /** + * Weighted geometric mean method for combining scores. + * + * We use formula below to calculate mean. It's based on fact that logarithm of geometric mean is the + * weighted arithmetic mean of the logarithms of individual scores. + * + * geometric_mean = exp(sum(weight_1*ln(score_1) + .... + weight_n*ln(score_n))/sum(weight_1 + ... + weight_n)) + */ + @Override + public float combine(final float[] scores) { + float weightedLnSum = 0; + float sumOfWeights = 0; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { + float score = scores[indexOfSubQuery]; + if (score <= 0) { + // scores 0.0 need to be skipped, ln() of 0 is not defined + continue; + } + float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); + sumOfWeights += weight; + weightedLnSum += weight * Math.log(score); + } + return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index d034ede16..f05d24823 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -24,7 +24,9 @@ public class ScoreCombinationFactory { ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil), HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, - params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil) + params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil), + GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil) ); /** diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 97a289bcd..ff221bf20 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -7,12 +7,18 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.opensearch.test.OpenSearchTestCase.randomFloat; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; +import org.apache.commons.lang3.Range; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; @@ -20,6 +26,8 @@ public class TestUtils { + private final static String RELATION_EQUAL_TO = "eq"; + /** * Convert an xContentBuilder to a map * @param xContentBuilder to produce map from @@ -58,7 +66,7 @@ public static float[] createRandomVector(int dimension) { } /** - * Assert results of hyrdir query after score normalization and combination + * Assert results of hyrdid query after score normalization and combination * @param querySearchResults collection of query search results after they processed by normalization processor */ public static void assertQueryResultScores(List querySearchResults) { @@ -94,4 +102,92 @@ public static void assertQueryResultScores(List querySearchRe .orElse(Float.MAX_VALUE); assertEquals(0.001f, minScoreScoreFromScoreDocs, 0.0f); } + + /** + * Assert results of hybrid query after score normalization and combination + * @param searchResponseWithWeightsAsMap collection of query search results after they processed by normalization processor + * @param expectedMaxScore expected maximum score + * @param expectedMaxMinusOneScore second highest expected score + * @param expectedMinScore expected minimal score + */ + public static void assertWeightedScores( + Map searchResponseWithWeightsAsMap, + double expectedMaxScore, + double expectedMaxMinusOneScore, + double expectedMinScore + ) { + assertNotNull(searchResponseWithWeightsAsMap); + Map totalWeights = getTotalHits(searchResponseWithWeightsAsMap); + assertNotNull(totalWeights.get("value")); + assertEquals(4, totalWeights.get("value")); + assertNotNull(totalWeights.get("relation")); + assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation")); + assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent()); + assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f); + + List scoresWeights = new ArrayList<>(); + for (Map oneHit : getNestedHits(searchResponseWithWeightsAsMap)) { + scoresWeights.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1))); + // verify the scores are normalized with inclusion of weights + assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001); + assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001); + assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001); + } + + /** + * Assert results of hybrid query after score normalization and combination + * @param searchResponseAsMap collection of query search results after they processed by normalization processor + * @param totalExpectedDocQty expected total document quantity + * @param minMaxScoreRange range of scores from min to max inclusive + */ + public static void assertHybridSearchResults( + Map searchResponseAsMap, + int totalExpectedDocQty, + float[] minMaxScoreRange + ) { + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get())); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range + assertTrue( + Range.between(minMaxScoreRange[0], minMaxScoreRange[1]) + .contains(scores.stream().map(Double::floatValue).max(Double::compare).get()) + ); + + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } + + private static List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + private static Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map) hitsMap.get("total"); + } + + private static Optional getMaxScore(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index add0205b0..02f9f8c9f 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -11,11 +11,13 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -623,4 +625,57 @@ protected void deleteSearchPipeline(final String pipelineId) { ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); } + + /** + * Find all modesl that are currently deployed in the cluster + * @return set of model ids + */ + @SneakyThrows + protected Set findDeployedModels() { + + StringBuilder stringBuilderForContentBody = new StringBuilder(); + stringBuilderForContentBody.append("{") + .append("\"query\": { \"match_all\": {} },") + .append(" \"_source\": {") + .append(" \"includes\": [\"model_id\"],") + .append(" \"excludes\": [\"content\", \"model_content\"]") + .append("}}"); + + Response response = makeRequest( + client(), + "POST", + "/_plugins/_ml/models/_search", + null, + toHttpEntity(stringBuilderForContentBody.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + + String responseBody = EntityUtils.toString(response.getEntity()); + + Map models = XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); + Set modelIds = new HashSet<>(); + if (Objects.isNull(models) || models.isEmpty()) { + return modelIds; + } + + Map hits = (Map) models.get("hits"); + List> innerHitsMap = (List>) hits.get("hits"); + return innerHitsMap.stream() + .map(hit -> (Map) hit.get("_source")) + .filter(hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id")) + .map(hitsMap -> (String) hitsMap.get("model_id")) + .collect(Collectors.toSet()); + } + + /** + * Get the id for model currently deployed in the cluster. If there are no models deployed or it's more than 1 model + * fail on assertion + * @return id of deployed model + */ + protected String getDeployedModelId() { + Set modelIds = findDeployedModels(); + assertEquals(1, modelIds.size()); + return modelIds.iterator().next(); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java similarity index 57% rename from src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java rename to src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 54271d042..a9b1fc9bf 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -9,13 +9,11 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import lombok.SneakyThrows; @@ -32,12 +30,11 @@ import com.google.common.primitives.Floats; -public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { +public class NormalizationProcessorIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; private static final String TEST_QUERY_TEXT3 = "hello"; private static final String TEST_QUERY_TEXT4 = "place"; - private static final String TEST_QUERY_TEXT5 = "welcome"; private static final String TEST_QUERY_TEXT6 = "notexistingword"; private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; private static final String TEST_DOC_TEXT1 = "Hello world"; @@ -49,24 +46,17 @@ public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; - private static final AtomicReference modelId = new AtomicReference<>(); private static final String SEARCH_PIPELINE = "phase-results-pipeline"; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); private final float[] testVector4 = createRandomVector(TEST_DIMENSION); private final static String RELATION_EQUAL_TO = "eq"; - private final static String RELATION_GREATER_OR_EQUAL_TO = "gte"; - - private static final String L2_NORMALIZATION_METHOD = "l2"; - private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; - private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; @Before public void setUp() throws Exception { super.setUp(); - updateClusterSettings(); - modelId.compareAndSet(null, prepareModel()); + prepareModel(); } @After @@ -74,16 +64,7 @@ public void setUp() throws Exception { public void tearDown() { super.tearDown(); deleteSearchPipeline(SEARCH_PIPELINE); - } - - @Override - public boolean isUpdateClusterSettings() { - return false; - } - - @Override - protected boolean preserveClusterUponCompletion() { - return true; + findDeployedModels().forEach(this::deleteModel); } /** @@ -108,8 +89,9 @@ protected boolean preserveClusterUponCompletion() { public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -142,8 +124,9 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSuccessful() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE); + String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -164,9 +147,10 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + String modelId = getDeployedModelId(); int totalExpectedDocQty = 6; - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 6, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 6, null, null); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -250,220 +234,6 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf assertQueryResults(searchResponseAsMap, 4, true); } - /** - * Using search pipelines with result processor configs like below: - * { - * "description": "Post processor for hybrid search", - * "phase_results_processors": [ - * { - * "normalization-processor": { - * "normalization": { - * "technique": "min-max" - * }, - * "combination": { - * "technique": "arithmetic_mean", - * "parameters": { - * "weights": [ - * 0.4, 0.7 - * ] - * } - * } - * } - * } - * ] - * } - */ - @SneakyThrows - public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); - // check case when number of weights and sub-queries are same - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) - ); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); - - Map searchResponseWithWeights1AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001); - - // delete existing pipeline and create a new one with another set of weights - deleteSearchPipeline(SEARCH_PIPELINE); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) - ); - - Map searchResponseWithWeights2AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001); - - // check case when number of weights is less than number of sub-queries - // delete existing pipeline and create a new one with another set of weights - deleteSearchPipeline(SEARCH_PIPELINE); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) - ); - - Map searchResponseWithWeights3AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); - - // check case when number of weights is more than number of sub-queries - // delete existing pipeline and create a new one with another set of weights - deleteSearchPipeline(SEARCH_PIPELINE); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) - ); - - Map searchResponseWithWeights4AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); - } - - /** - * Using search pipelines with config for l2 norm: - * { - * "description": "Post processor for hybrid search", - * "phase_results_processors": [ - * { - * "normalization-processor": { - * "normalization": { - * "technique": "l2" - * }, - * "combination": { - * "technique": "arithmetic_mean" - * } - * } - * } - * ] - * } - */ - @SneakyThrows - public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); - createSearchPipeline( - SEARCH_PIPELINE, - L2_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) - ); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); - hybridQueryBuilder.add(termQueryBuilder); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - int totalExpectedDocQty = 5; - float[] minMaxExpectedScoresRange = { 0.6f, 1.0f }; - assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); - } - - @SneakyThrows - public void testMinMaxNormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) - ); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); - hybridQueryBuilder.add(termQueryBuilder); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - int totalExpectedDocQty = 5; - float[] minMaxExpectedScoresRange = { 0.6f, 1.0f }; - assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); - } - - @SneakyThrows - public void testL2NormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); - createSearchPipeline( - SEARCH_PIPELINE, - L2_NORMALIZATION_METHOD, - HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) - ); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); - hybridQueryBuilder.add(termQueryBuilder); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - int totalExpectedDocQty = 5; - float[] minMaxExpectedScoresRange = { 0.5f, 1.0f }; - assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); - } - private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { prepareKnnIndex( @@ -616,60 +386,4 @@ private void assertQueryResults(Map searchResponseAsMap, int tot // verify that all ids are unique assertEquals(Set.copyOf(ids).size(), ids.size()); } - - private void assertWeightedScores( - Map searchResponseWithWeightsAsMap, - double expectedMaxScore, - double expectedMaxMinusOneScore, - double expectedMinScore - ) { - assertNotNull(searchResponseWithWeightsAsMap); - Map totalWeights = getTotalHits(searchResponseWithWeightsAsMap); - assertNotNull(totalWeights.get("value")); - assertEquals(4, totalWeights.get("value")); - assertNotNull(totalWeights.get("relation")); - assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation")); - assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent()); - assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f); - - List scoresWeights = new ArrayList<>(); - for (Map oneHit : getNestedHits(searchResponseWithWeightsAsMap)) { - scoresWeights.add((Double) oneHit.get("_score")); - } - // verify scores order - assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1))); - // verify the scores are normalized with inclusion of weights - assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001); - assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001); - assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001); - } - - private void assertHybridSearchResults(Map searchResponseAsMap, int totalExpectedDocQty, float[] minMaxScoreRange) { - assertNotNull(searchResponseAsMap); - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(totalExpectedDocQty, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get())); - - List> hitsNestedList = getNestedHits(searchResponseAsMap); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hitsNestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - // verify scores order - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range - assertTrue( - Range.between(minMaxScoreRange[0], minMaxScoreRange[1]) - .contains(scores.stream().map(Double::floatValue).max(Double::compare).get()) - ); - - // verify that all ids are unique - assertEquals(Set.copyOf(ids).size(), ids.size()); - } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java new file mode 100644 index 000000000..e56532b52 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -0,0 +1,412 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores; +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreCombinationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + + @Before + public void setUp() throws Exception { + super.setUp(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + findDeployedModels().forEach(this::deleteModel); + } + + /** + * Using search pipelines with result processor configs like below: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "min-max" + * }, + * "combination": { + * "technique": "arithmetic_mean", + * "parameters": { + * "weights": [ + * 0.4, 0.7 + * ] + * } + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + // check case when number of weights and sub-queries are same + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseWithWeights1AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001); + + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) + ); + + Map searchResponseWithWeights2AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001); + + // check case when number of weights is less than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) + ); + + Map searchResponseWithWeights3AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); + + // check case when number of weights is more than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) + ); + + Map searchResponseWithWeights4AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); + } + + /** + * Using search pipelines with config for harmonic mean: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "harmonic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapDefaultNorm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderDefaultNorm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertHybridSearchResults(searchResponseAsMapDefaultNorm, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapL2Norm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderL2Norm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapL2Norm, 5, new float[] { 0.5f, 1.0f }); + } + + /** + * Using search pipelines with config for geometric mean: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "geometric_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapDefaultNorm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderDefaultNorm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertHybridSearchResults(searchResponseAsMapDefaultNorm, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapL2Norm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderL2Norm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapL2Norm, 5, new float[] { 0.5f, 1.0f }); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 3 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "6", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5) + ); + assertEquals(6, getDocCount(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java new file mode 100644 index 000000000..64b3fe07f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -0,0 +1,295 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreNormalizationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + findDeployedModels().forEach(this::deleteModel); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + /** + * Using search pipelines with config for l2 norm: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "arithmetic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); + hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapArithmeticMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderArithmeticMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 0.6f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); + hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapHarmonicMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderHarmonicMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapHarmonicMean, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); + hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapGeometricMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderGeometricMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapGeometricMean, 5, new float[] { 0.5f, 1.0f }); + } + + /** + * Using search pipelines with config for min-max norm: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "arithmetic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + String modelId = getDeployedModelId(); + + HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); + hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapArithmeticMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderArithmeticMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 1.0f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); + hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapHarmonicMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderHarmonicMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapHarmonicMean, 5, new float[] { 0.6f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); + hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapGeometricMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderGeometricMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapGeometricMean, 5, new float[] { 0.6f, 1.0f }); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 7cea20a41..de7a70add 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -9,9 +9,12 @@ import java.nio.file.Path; import java.util.Map; +import lombok.SneakyThrows; + import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.After; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; @@ -25,6 +28,17 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private static final String PIPELINE_NAME = "pipeline-hybrid"; + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + findDeployedModels().forEach(this::deleteModel); + } + public void testTextEmbeddingProcessor() throws Exception { String modelId = uploadTextEmbeddingModel(); loadModel(modelId); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 3c3ca3776..842df736d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -9,39 +9,60 @@ import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.carrotsearch.randomizedtesting.RandomizedTest; public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + public ArithmeticMeanScoreCombinationTechniqueTests() { this.expectedScoreFunction = this::arithmeticMean; } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } - public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = List.of(0.9, 0.2, 0.7); + public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), - new ScoreCombinationUtil() + scoreCombinationUtil ); - testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, -1.0f, 0.6f); List weights = List.of(0.9, 0.2, 0.7); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), - new ScoreCombinationUtil() + scoreCombinationUtil + ); + float expectedScore = 0.825f; + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil ); - testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } private float arithmeticMean(List scores, List weights) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java index cf9d1080f..e8ad3532d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java @@ -14,10 +14,13 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.test.OpenSearchTestCase; +import com.carrotsearch.randomizedtesting.RandomizedTest; + @NoArgsConstructor public class BaseScoreCombinationTechniqueTests extends OpenSearchTestCase { protected BiFunction, List, Float> expectedScoreFunction; + protected static final int RANDOM_SCORES_SIZE = 100; private static final float DELTA_FOR_ASSERTION = 0.0001f; @@ -37,9 +40,25 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(fina public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores( final ScoreCombinationTechnique technique, - List weights + List scores, + float expectedScore ) { - float[] scores = { 1.0f, 0.5f, 0.3f }; + float[] scoresArray = new float[scores.size()]; + for (int i = 0; i < scoresArray.length; i++) { + scoresArray[i] = scores.get(i); + } + float actualScore = technique.combine(scoresArray); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + final List weights + ) { + float[] scores = new float[weights.size()]; + for (int i = 0; i < RANDOM_SCORES_SIZE; i++) { + scores[i] = randomScore(); + } float actualScore = technique.combine(scores); float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); @@ -47,11 +66,31 @@ public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores( public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores( final ScoreCombinationTechnique technique, - List weights + List scores, + float expectedScore ) { - float[] scores = { 1.0f, -1.0f, 0.6f }; + float[] scoresArray = new float[scores.size()]; + for (int i = 0; i < scoresArray.length; i++) { + scoresArray[i] = scores.get(i); + } + float actualScore = technique.combine(scoresArray); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + final List weights + ) { + float[] scores = new float[weights.size()]; + for (int i = 0; i < RANDOM_SCORES_SIZE; i++) { + scores[i] = randomScore(); + } float actualScore = technique.combine(scores); float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); } + + private float randomScore() { + return RandomizedTest.randomBoolean() ? -1.0f : RandomizedTest.randomFloat(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..fe0d962ca --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public GeometricMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = this::geometricMean; + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, 0.5f, 0.3f); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + float expectedScore = 0.5797f; + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, -1.0f, 0.6f); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + float expectedScore = 0.7997f; + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil + ); + testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + /** + * Verify score correctness by using alternative formula for geometric mean as n-th root of product of weighted scores, + * more details in here https://en.wikipedia.org/wiki/Weighted_geometric_mean + */ + private float geometricMean(List scores, List weights) { + float product = 1.0f; + float sumOfWeights = 0.0f; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.size(); indexOfSubQuery++) { + float score = scores.get(indexOfSubQuery); + if (score <= 0) { + // scores 0.0 need to be skipped, ln() of 0 is not defined + continue; + } + float weight = weights.get(indexOfSubQuery).floatValue(); + product *= Math.pow(score, weight); + sumOfWeights += weight; + } + return sumOfWeights == 0 ? 0f : (float) Math.pow(product, (float) 1 / sumOfWeights); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 02a8084ef..8187123a1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -9,39 +9,60 @@ import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.carrotsearch.randomizedtesting.RandomizedTest; public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + public HarmonicMeanScoreCombinationTechniqueTests() { this.expectedScoreFunction = (scores, weights) -> harmonicMean(scores, weights); } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, 0.5f, 0.3f); List weights = List.of(0.9, 0.2, 0.7); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), - new ScoreCombinationUtil() + scoreCombinationUtil ); - testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + float expecteScore = 0.4954f; + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expecteScore); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List scores = List.of(1.0f, -1.0f, 0.6f); List weights = List.of(0.9, 0.2, 0.7); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), - new ScoreCombinationUtil() + scoreCombinationUtil + ); + float expectedScore = 0.7741f; + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); + } + + public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = IntStream.range(0, RANDOM_SCORES_SIZE) + .mapToObj(i -> RandomizedTest.randomDouble()) + .collect(Collectors.toList()); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + scoreCombinationUtil ); - testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } private float harmonicMean(List scores, List weights) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java index 9f164b3d3..ee46bf0a1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -27,6 +27,14 @@ public void testHarmonicWeightedMean_whenCreatingByName_thenReturnCorrectInstanc assertTrue(scoreCombinationTechnique instanceof HarmonicMeanScoreCombinationTechnique); } + public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("geometric_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof GeometricMeanScoreCombinationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 3414210b9..59c90f495 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -14,18 +14,16 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import java.util.stream.IntStream; import lombok.SneakyThrows; +import org.junit.After; import org.junit.Before; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.SpaceType; -import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; import com.google.common.primitives.Floats; @@ -46,24 +44,30 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; - private static final String TEST_TEXT_FIELD_NAME_2 = "test-text-field-2"; - private static final String TEST_TEXT_FIELD_NAME_3 = "test-text-field-3"; private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; - private static final AtomicReference modelId = new AtomicReference<>(); - private static final float EXPECTED_SCORE_BM25 = 0.287682082504034f; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); private final static String RELATION_EQUAL_TO = "eq"; - private final static String RELATION_GREATER_OR_EQUAL_TO = "gte"; @Before public void setUp() throws Exception { super.setUp(); updateClusterSettings(); - modelId.compareAndSet(null, prepareModel()); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + findDeployedModels().forEach(this::deleteModel); } @Override @@ -99,8 +103,9 @@ protected boolean preserveClusterUponCompletion() { @SneakyThrows public void testBasicQuery_whenOneSubQuery_thenSuccessful() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(neuralQueryBuilder); @@ -130,46 +135,6 @@ public void testBasicQuery_whenOneSubQuery_thenSuccessful() { assertTrue(getMaxScore(searchResponseAsMap1).isPresent()); } - @SneakyThrows - public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect() { - initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( - TEST_KNN_VECTOR_FIELD_NAME_1, - TEST_QUERY_TEXT, - modelId.get(), - 3, - null, - null - ); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - - HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); - hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder); - hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder); - - Map searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3); - - assertEquals(3, getHitCount(searchResponseAsMap)); - - List> hitsNestedList = getNestedHits(searchResponseAsMap); - List scores = new ArrayList<>(); - for (Map oneHit : hitsNestedList) { - scores.add((Double) oneHit.get("_score")); - } - - List expectedScores = List.of( - computeExpectedScore(modelId.get(), testVector1, TEST_SPACE_TYPE, TEST_QUERY_TEXT), - computeExpectedScore(modelId.get(), testVector2, TEST_SPACE_TYPE, TEST_QUERY_TEXT), - computeExpectedScore(modelId.get(), testVector3, TEST_SPACE_TYPE, TEST_QUERY_TEXT) - ); - List actualScores = scores.stream().map(TestUtils::objectToFloat).collect(Collectors.toList()); - assertTrue(expectedScores.containsAll(actualScores)); - - float expectedMaxScore = Math.max(expectedScores.stream().max(Float::compareTo).get(), EXPECTED_SCORE_BM25); - assertEquals(expectedMaxScore, getMaxScore(searchResponseAsMap).get(), 0.001f); - } - /** * Tests complex query with multiple nested sub-queries: * { @@ -268,8 +233,9 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { @SneakyThrows public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); @@ -326,7 +292,7 @@ public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() { } @SneakyThrows - public void test_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() { + public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index ab3ecb4dd..afe3226c8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -11,7 +11,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; import lombok.SneakyThrows; @@ -39,13 +38,13 @@ public class NeuralQueryIT extends BaseNeuralSearchIT { private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; - private static final AtomicReference modelId = new AtomicReference<>(); private final float[] testVector = createRandomVector(TEST_DIMENSION); @Before public void setUp() throws Exception { super.setUp(); - modelId.compareAndSet(modelId.get(), prepareModel()); + updateClusterSettings(); + prepareModel(); } @After @@ -56,7 +55,7 @@ public void tearDown() { * this happens in case we leave model from previous test case. We use new model for every test, and old model * can be undeployed and deleted to free resources after each test case execution. */ - deleteModel(modelId.get()); + findDeployedModels().forEach(this::deleteModel); } /** @@ -76,10 +75,11 @@ public void tearDown() { @SneakyThrows public void testBasicQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -88,7 +88,7 @@ public void testBasicQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -110,10 +110,11 @@ public void testBasicQuery() { @SneakyThrows public void testBoostQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -125,7 +126,7 @@ public void testBoostQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = 2 * computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -152,11 +153,12 @@ public void testBoostQuery() { @SneakyThrows public void testRescoreQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); NeuralQueryBuilder rescoreNeuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -166,7 +168,7 @@ public void testRescoreQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -198,12 +200,13 @@ public void testRescoreQuery() { @SneakyThrows public void testBooleanQuery_withMultipleNeuralQueries() { initializeIndexIfNotExist(TEST_MULTI_VECTOR_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -211,7 +214,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_2, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -223,7 +226,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = 2 * computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -253,12 +256,13 @@ public void testBooleanQuery_withMultipleNeuralQueries() { @SneakyThrows public void testBooleanQuery_withNeuralAndBM25Queries() { initializeIndexIfNotExist(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -272,7 +276,7 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float minExpectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float minExpectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertTrue(minExpectedScore < objectToFloat(firstInnerHit.get("_score"))); } @@ -297,11 +301,12 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { @SneakyThrows public void testNestedQuery() { initializeIndexIfNotExist(TEST_NESTED_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_NESTED, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, null @@ -311,7 +316,7 @@ public void testNestedQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } @@ -339,10 +344,11 @@ public void testNestedQuery() { @SneakyThrows public void testFilterQuery() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + String modelId = getDeployedModelId(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, - modelId.get(), + modelId, 1, null, new MatchQueryBuilder("_id", "3") @@ -351,7 +357,7 @@ public void testFilterQuery() { assertEquals(1, getHitCount(searchResponseAsMap)); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("3", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); }