Skip to content

Commit

Permalink
Add geometric mean normalization for scores
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 2, 2023
1 parent 1f67b94 commit a62afaa
Show file tree
Hide file tree
Showing 14 changed files with 1,067 additions and 376 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Abstracts combination of scores based on geometrical mean method
*/
public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "geometric_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
private final ScoreCombinationUtil scoreCombinationUtil;

public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
scoreCombinationUtil = combinationUtil;
scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
weights = scoreCombinationUtil.getWeights(params);
}

/**
* Weighted geometric mean method for combining scores.
*
* We use formula below to calculate mean. It's based on fact that logarithm of geometric mean is the
* weighted arithmetic mean of the logarithms of individual scores.
*
* geometric_mean = exp(sum(weight_1*ln(score_1) + .... + weight_n*ln(score_n))/sum(weight_1 + ... + weight_n))
*/
@Override
public float combine(final float[] scores) {
float weightedLnSum = 0;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score <= 0) {
// scores 0.0 need to be skipped, ln() of 0 is not defined
continue;
}
float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
sumOfWeights += weight;
weightedLnSum += weight * Math.log(score);
}
return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ public class ScoreCombinationFactory {
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil),
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil)
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil),
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil)
);

/**
Expand Down
98 changes: 97 additions & 1 deletion src/test/java/org/opensearch/neuralsearch/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,27 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.opensearch.test.OpenSearchTestCase.randomFloat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Range;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.search.query.QuerySearchResult;

public class TestUtils {

private final static String RELATION_EQUAL_TO = "eq";

/**
* Convert an xContentBuilder to a map
* @param xContentBuilder to produce map from
Expand Down Expand Up @@ -58,7 +66,7 @@ public static float[] createRandomVector(int dimension) {
}

/**
* Assert results of hyrdir query after score normalization and combination
* Assert results of hyrdid query after score normalization and combination
* @param querySearchResults collection of query search results after they processed by normalization processor
*/
public static void assertQueryResultScores(List<QuerySearchResult> querySearchResults) {
Expand Down Expand Up @@ -94,4 +102,92 @@ public static void assertQueryResultScores(List<QuerySearchResult> querySearchRe
.orElse(Float.MAX_VALUE);
assertEquals(0.001f, minScoreScoreFromScoreDocs, 0.0f);
}

/**
* Assert results of hybrid query after score normalization and combination
* @param searchResponseWithWeightsAsMap collection of query search results after they processed by normalization processor
* @param expectedMaxScore expected maximum score
* @param expectedMaxMinusOneScore second highest expected score
* @param expectedMinScore expected minimal score
*/
public static void assertWeightedScores(
Map<String, Object> searchResponseWithWeightsAsMap,
double expectedMaxScore,
double expectedMaxMinusOneScore,
double expectedMinScore
) {
assertNotNull(searchResponseWithWeightsAsMap);
Map<String, Object> totalWeights = getTotalHits(searchResponseWithWeightsAsMap);
assertNotNull(totalWeights.get("value"));
assertEquals(4, totalWeights.get("value"));
assertNotNull(totalWeights.get("relation"));
assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation"));
assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent());
assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f);

List<Double> scoresWeights = new ArrayList<>();
for (Map<String, Object> oneHit : getNestedHits(searchResponseWithWeightsAsMap)) {
scoresWeights.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1)));
// verify the scores are normalized with inclusion of weights
assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001);
assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001);
assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001);
}

/**
* Assert results of hybrid query after score normalization and combination
* @param searchResponseAsMap collection of query search results after they processed by normalization processor
* @param totalExpectedDocQty expected total document quantity
* @param minMaxScoreRange range of scores from min to max inclusive
*/
public static void assertHybridSearchResults(
Map<String, Object> searchResponseAsMap,
int totalExpectedDocQty,
float[] minMaxScoreRange
) {
assertNotNull(searchResponseAsMap);
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(totalExpectedDocQty, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get()));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<String> ids = new ArrayList<>();
List<Double> scores = new ArrayList<>();
for (Map<String, Object> oneHit : hitsNestedList) {
ids.add((String) oneHit.get("_id"));
scores.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
// verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range
assertTrue(
Range.between(minMaxScoreRange[0], minMaxScoreRange[1])
.contains(scores.stream().map(Double::floatValue).max(Double::compare).get())
);

// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());
}

private static List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
}

private static Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (Map<String, Object>) hitsMap.get("total");
}

private static Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -623,4 +625,60 @@ protected void deleteSearchPipeline(final String pipelineId) {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}

/**
* Find all modesl that are currently deployed in the cluster
* @return set of model ids
*/
@SneakyThrows
protected Set<String> findDeployedModels() {

StringBuilder stringBuilderForContentBody = new StringBuilder();
stringBuilderForContentBody.append("{")
.append("\"query\": {\n" + " \"match_all\": {}\n" + " },")
.append(
" \"_source\": {\n"
+ " \"includes\": [\"model_id\"],\n"
+ " \"excludes\": [\"content\", \"model_content\"]\n"
+ " }"
)
.append("}");

Response response = makeRequest(
client(),
"POST",
"/_plugins/_ml/models/_search",
null,
toHttpEntity(stringBuilderForContentBody.toString()),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);

String responseBody = EntityUtils.toString(response.getEntity());

Map<String, Object> models = XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
Set<String> modelIds = new HashSet<>();
if (Objects.isNull(models) || models.isEmpty()) {
return modelIds;
}

Map<String, Object> hits = (Map<String, Object>) models.get("hits");
List<Map<String, Object>> innerHitsMap = (List<Map<String, Object>>) hits.get("hits");
return innerHitsMap.stream()
.map(hit -> (Map<String, Object>) hit.get("_source"))
.filter(hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id"))
.map(hitsMap -> (String) hitsMap.get("model_id"))
.collect(Collectors.toSet());
}

/**
* Get the id for model currently deployed in the cluster. If there are no models deployed or it's more than 1 model
* fail on assertion
* @return id of deployed model
*/
protected String getDeployedModelId() {
Set<String> modelIds = findDeployedModels();
assertEquals(1, modelIds.size());
return modelIds.iterator().next();
}

}
Loading

0 comments on commit a62afaa

Please sign in to comment.