Adding L2 norm technique #236

Merged
Changes from 2 commits
@@ -0,0 +1,80 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.neuralsearch.processor.normalization;

import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.search.CompoundTopDocs;

/**
 * Abstracts normalization of scores based on the L2 method
 */
public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique {

    public static final String TECHNIQUE_NAME = "l2";
    private static final float MIN_SCORE = 0.001f;

    /**
     * L2 normalization method.
     * n_score_i = score_i / sqrt(score_1^2 + score_2^2 + ... + score_n^2)
     * Main algorithm steps:
     * - calculate the sum of squares of all scores
     * - iterate over each result and update its score using the formula above, where "score" is the raw score returned by the hybrid query
     */
    @Override
    public void normalize(final List<CompoundTopDocs> queryTopDocs) {
        int numOfSubqueries = queryTopDocs.stream()
            .filter(Objects::nonNull)
            .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
            .findAny()
            .get()
            .getCompoundTopDocs()
            .size();
        // get l2 norms for each sub-query
        float[] normsPerSubquery = getL2Norm(queryTopDocs, numOfSubqueries);

        // do normalization using actual score and l2 norm
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery[j]);
                }
            }
        }
    }

    private float[] getL2Norm(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
        float[] l2Norms = new float[numOfSubqueries];
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
            IntStream.range(0, topDocsPerSubQuery.size()).forEach(index -> {
                for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) {
                    l2Norms[index] += scoreDocs.score * scoreDocs.score;
                }
            });
        }
        for (int index = 0; index < l2Norms.length; index++) {
            l2Norms[index] = (float) Math.sqrt(l2Norms[index]);
        }
        return l2Norms;
    }

    private float normalizeSingleScore(final float score, final float l2Norm) {
        return l2Norm == 0 ? MIN_SCORE : score / l2Norm;
    }
}
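For readers skimming the diff, the normalization itself is just per-sub-query scaling by the L2 norm of that sub-query's scores. Below is a minimal, self-contained sketch of the same arithmetic on made-up raw scores, using plain Java arrays instead of the plugin's CompoundTopDocs/TopDocs types; the 0.001f fallback mirrors the MIN_SCORE constant above.

import java.util.Arrays;

public class L2NormalizationSketch {
    public static void main(String[] args) {
        // made-up raw scores for a single sub-query of a hybrid query
        float[] scores = { 3.0f, 4.0f };

        // step 1: sum of squares, then square root -> L2 norm (sqrt(9 + 16) = 5)
        float sumOfSquares = 0.0f;
        for (float score : scores) {
            sumOfSquares += score * score;
        }
        float l2Norm = (float) Math.sqrt(sumOfSquares);

        // step 2: divide every raw score by the norm, falling back to a small
        // constant when the norm is zero (mirrors MIN_SCORE above)
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            normalized[i] = l2Norm == 0 ? 0.001f : scores[i] / l2Norm;
        }

        System.out.println(Arrays.toString(normalized)); // [0.6, 0.8]
    }
}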
@@ -17,7 +17,9 @@ public class ScoreNormalizationFactory {

private final Map<String, ScoreNormalizationTechnique> scoreNormalizationMethodsMap = Map.of(
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
- new MinMaxScoreNormalizationTechnique()
+ new MinMaxScoreNormalizationTechnique(),
+ L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
+ new L2ScoreNormalizationTechnique()
);

/**
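The factory keeps a name-to-instance map, so the normalization processor can resolve whichever technique a search pipeline names ("min-max" or, with this change, "l2"). The snippet below is an illustrative standalone sketch of that lookup pattern, not the factory's actual accessor, and it assumes both technique classes live in the normalization package shown above.

import java.util.List;
import java.util.Map;

import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.CompoundTopDocs;

public class NormalizationLookupSketch {

    // same registration pattern as the factory's map above
    private static final Map<String, ScoreNormalizationTechnique> TECHNIQUES = Map.of(
        MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, new MinMaxScoreNormalizationTechnique(),
        L2ScoreNormalizationTechnique.TECHNIQUE_NAME, new L2ScoreNormalizationTechnique()
    );

    // "technique" is the string from the pipeline config, e.g. "l2"
    static void normalizeWith(String technique, List<CompoundTopDocs> queryTopDocs) {
        ScoreNormalizationTechnique normalizer = TECHNIQUES.get(technique);
        if (normalizer == null) {
            throw new IllegalArgumentException("Unknown normalization technique: " + technique);
        }
        normalizer.normalize(queryTopDocs);
    }
}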
@@ -107,7 +107,7 @@ protected String uploadModel(String requestBody) throws Exception {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
- XContentFactory.xContent(XContentType.JSON),
+ XContentType.JSON.xContent(),
EntityUtils.toString(uploadResponse.getEntity()),
false
);
@@ -136,7 +136,7 @@ protected void loadModel(String modelId) throws Exception {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
- XContentFactory.xContent(XContentType.JSON),
+ XContentType.JSON.xContent(),
EntityUtils.toString(uploadResponse.getEntity()),
false
);
@@ -185,7 +185,7 @@ protected float[] runInference(String modelId, String queryText) {
);

Map<String, Object> inferenceResJson = XContentHelper.convertToMap(
- XContentFactory.xContent(XContentType.JSON),
+ XContentType.JSON.xContent(),
EntityUtils.toString(inferenceResponse.getEntity()),
false
);
@@ -215,7 +215,7 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
- XContentFactory.xContent(XContentType.JSON),
+ XContentType.JSON.xContent(),
EntityUtils.toString(response.getEntity()),
false
);
@@ -239,7 +239,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
- XContentFactory.xContent(XContentType.JSON),
+ XContentType.JSON.xContent(),
EntityUtils.toString(pipelineCreateResponse.getEntity()),
false
);
@@ -329,7 +329,7 @@ protected Map<String, Object> search(

String responseBody = EntityUtils.toString(response.getEntity());

- return XContentHelper.convertToMap(XContentFactory.xContent(XContentType.JSON), responseBody, false);
+ return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
}

/**
@@ -445,11 +445,7 @@ protected Map<String, Object> getTaskQueryResponse(String taskId) throws Excepti
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
- return XContentHelper.convertToMap(
- XContentFactory.xContent(XContentType.JSON),
- EntityUtils.toString(taskQueryResponse.getEntity()),
- false
- );
+ return XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(taskQueryResponse.getEntity()), false);
}

protected boolean checkComplete(Map<String, Object> node) {
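The remaining changes in this test base class are a mechanical refactor: XContentFactory.xContent(XContentType.JSON) becomes the shorter XContentType.JSON.xContent(), with the XContentHelper.convertToMap call otherwise unchanged. A minimal sketch of the resulting parsing idiom, with a made-up response body standing in for EntityUtils.toString(response.getEntity()):

import java.util.Map;

import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;

public class JsonResponseParsingSketch {
    public static void main(String[] args) {
        // stand-in for EntityUtils.toString(response.getEntity()); the keys are illustrative
        String responseBody = "{\"acknowledged\":true,\"model_id\":\"abc123\"}";

        // XContentType.JSON.xContent() replaces XContentFactory.xContent(XContentType.JSON)
        Map<String, Object> asMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);

        System.out.println(asMap.get("model_id")); // abc123
    }
}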
@@ -27,6 +27,7 @@
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
+ import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

@@ -93,12 +94,7 @@ protected boolean preserveClusterUponCompletion() {
* "technique": "min-max"
* },
* "combination": {
* "technique": "sum",
* "parameters": {
* "weights": [
* 0.4, 0.7
* ]
* }
* "technique": "arithmetic_mean"
* }
* }
* }
@@ -251,6 +247,29 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf
assertQueryResults(searchResponseAsMap, 4, true);
}

/**
* Using search pipelines with result processor configs like below:
* {
* "description": "Post processor for hybrid search",
* "phase_results_processors": [
* {
* "normalization-processor": {
* "normalization": {
* "technique": "min-max"
* },
* "combination": {
* "technique": "arithmetic_mean",
* "parameters": {
* "weights": [
* 0.4, 0.7
* ]
* }
* }
* }
* }
* ]
* }
*/
@SneakyThrows
public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
@@ -337,6 +356,74 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001);
}

/**
* Using search pipelines with config for l2 norm:
* {
* "description": "Post processor for hybrid search",
* "phase_results_processors": [
* {
* "normalization-processor": {
* "normalization": {
* "technique": "l2"
* },
* "combination": {
* "technique": "arithmetic_mean"
* }
* }
* }
* ]
* }
*/
@SneakyThrows
public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
createSearchPipeline(
SEARCH_PIPELINE,
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(neuralQueryBuilder);
hybridQueryBuilder.add(termQueryBuilder);

Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);
int totalExpectedDocQty = 5;
assertNotNull(searchResponseAsMap);
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(totalExpectedDocQty, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(Range.between(.6f, 1.0f).contains(getMaxScore(searchResponseAsMap).get()));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<String> ids = new ArrayList<>();
List<Double> scores = new ArrayList<>();
for (Map<String, Object> oneHit : hitsNestedList) {
ids.add((String) oneHit.get("_id"));
scores.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
// verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range
assertTrue(Range.between(.6f, 1.0f).contains((float) scores.stream().map(Double::floatValue).max(Double::compare).get()));

// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());
}
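The range check above (max score between 0.6 and 1.0 rather than exactly 1.0) follows from the L2 math: whenever a sub-query has more than one non-zero hit, its top normalized score is strictly below 1.0. A small sketch of that bound, using made-up raw scores:

public class L2MaxScoreBoundSketch {
    public static void main(String[] args) {
        // made-up raw scores for one sub-query; only a single non-zero hit would normalize to exactly 1.0
        float[] scores = { 2.5f, 1.2f, 0.4f };

        float sumOfSquares = 0.0f;
        float maxScore = 0.0f;
        for (float score : scores) {
            sumOfSquares += score * score;
            maxScore = Math.max(maxScore, score);
        }
        float l2Norm = (float) Math.sqrt(sumOfSquares);

        // top normalized score = 2.5 / sqrt(6.25 + 1.44 + 0.16) ~= 0.89, i.e. below 1.0
        System.out.println(maxScore / l2Norm);
    }
}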

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
prepareKnnIndex(
@@ -13,7 +13,6 @@
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.message.BasicHeader;
import org.opensearch.client.Response;
- import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
@@ -71,7 +70,7 @@ private void ingestDocument() throws Exception {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
);
Map<String, Object> map = XContentHelper.convertToMap(
- XContentFactory.xContent(XContentType.JSON),
+ XContentType.JSON.xContent(),
EntityUtils.toString(response.getEntity()),
false
);