Adding search processor for score normalization and combination #227

Merged
NeuralSearch.java
@@ -27,9 +27,12 @@
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
+import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
+import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
+import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
@@ -61,6 +64,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
@VisibleForTesting
public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled";
private MLCommonsClientAccessor clientAccessor;
+private NormalizationProcessorWorkflow normalizationProcessorWorkflow;

@Override
public Collection<Object> createComponents(
@@ -77,6 +81,7 @@ public Collection<Object> createComponents(
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralQueryBuilder.initialize(clientAccessor);
+normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}

@@ -109,6 +114,6 @@ public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(
Parameters parameters
) {
-return Map.of(NormalizationProcessor.TYPE, new NormalizationProcessorFactory());
+return Map.of(NormalizationProcessor.TYPE, new NormalizationProcessorFactory(normalizationProcessorWorkflow));
}
}
ArithmeticMeanScoreCombinationMethod.java
@@ -15,6 +15,7 @@
public class ArithmeticMeanScoreCombinationMethod implements ScoreCombinationMethod {

private static final ArithmeticMeanScoreCombinationMethod INSTANCE = new ArithmeticMeanScoreCombinationMethod();
+private static final Float ZERO_SCORE = 0.0f;

public static ArithmeticMeanScoreCombinationMethod getInstance() {
return INSTANCE;
@@ -36,6 +37,9 @@ public float combine(final float[] scores) {
count++;
}
}
+if (count == 0) {
+return ZERO_SCORE;
+}
return combinedScore / count;
Collaborator:

Please add a unit test where count is zero.

Member Author:

Ack, I'll add more tests in the next PRs.

Collaborator:

Umm, I thought that dividing by zero would throw an exception, but I was wrong. For floating-point numbers it produces Infinity, -Infinity, or NaN. If those values are okay, I am fine.

Member Author:

In that case we need to return 0.0: count stays at zero when there are scores but we skip all of them. Let me update the method.

}
}
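As the thread above notes, 0.0f / 0 evaluates to NaN for floats rather than throwing, so the count == 0 guard is what keeps combine from returning NaN. Below is a minimal sketch of the unit test the reviewer asked for; the test class and method names are hypothetical, and an empty scores array is used as one input that is certain to leave count at zero.

import org.opensearch.test.OpenSearchTestCase;

public class ArithmeticMeanScoreCombinationMethodTests extends OpenSearchTestCase {

    public void testCombine_whenNoScoresCounted_thenReturnZero() {
        ArithmeticMeanScoreCombinationMethod method = ArithmeticMeanScoreCombinationMethod.getInstance();
        // an empty array never increments count; without the guard this would be 0.0f / 0 = NaN
        assertEquals(0.0f, method.combine(new float[0]), 0.0f);
    }
}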
ScoreCombiner.java
@@ -5,12 +5,12 @@

package org.opensearch.neuralsearch.processor.combination;

+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.PriorityQueue;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;
@@ -20,6 +20,8 @@
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.search.CompoundTopDocs;

+import com.google.common.annotations.VisibleForTesting;

/**
* Abstracts combination of scores in query search results.
*/
@@ -50,7 +52,8 @@ public List<Float> combineScores(final List<CompoundTopDocs> queryTopDocs, final
.collect(Collectors.toList());
}

-private float combineShardScores(
+@VisibleForTesting
+protected float combineShardScores(
final ScoreCombinationTechnique scoreCombinationTechnique,
final CompoundTopDocs compoundQueryTopDocs
) {
@@ -68,42 +71,46 @@ private float combineShardScores(
);

// - sort documents by scores and take first "max number" of docs
-// create a priority queue of doc ids that are sorted by their combined scores
-PriorityQueue<Integer> scoreQueue = getPriorityQueueOfDocIds(combinedNormalizedScoresByDocId);
-// store max score to resulting list, call it now as priority queue will change after combining scores
-float maxScore = combinedNormalizedScoresByDocId.get(scoreQueue.peek());
+// create a collection of doc ids that are sorted by their combined scores
+List<Integer> sortedDocsIds = getPriorityQueueOfDocIds(combinedNormalizedScoresByDocId);

// - update query search results with normalized scores
-updateQueryTopDocsWithCombinedScores(compoundQueryTopDocs, topDocsPerSubQuery, combinedNormalizedScoresByDocId, scoreQueue);
-return maxScore;
+updateQueryTopDocsWithCombinedScores(compoundQueryTopDocs, topDocsPerSubQuery, combinedNormalizedScoresByDocId, sortedDocsIds);
+
+// return max score
+if (sortedDocsIds.isEmpty()) {
+return ZERO_SCORE;
+}
+return combinedNormalizedScoresByDocId.get(sortedDocsIds.get(0));
}

-private PriorityQueue<Integer> getPriorityQueueOfDocIds(final Map<Integer, Float> combinedNormalizedScoresByDocId) {
-PriorityQueue<Integer> pq = new PriorityQueue<>(
-(a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a))
-);
+@VisibleForTesting
+protected List<Integer> getPriorityQueueOfDocIds(final Map<Integer, Float> combinedNormalizedScoresByDocId) {
// we're merging docs with normalized and combined scores. we need to have only maxHits results
-pq.addAll(combinedNormalizedScoresByDocId.keySet());
-return pq;
+List<Integer> sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet());
+sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a)));
+return sortedDocsIds;
}

-private ScoreDoc[] getCombinedScoreDocs(
+@VisibleForTesting
+protected ScoreDoc[] getCombinedScoreDocs(
final CompoundTopDocs compoundQueryTopDocs,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
-final PriorityQueue<Integer> scoreQueue,
+final List<Integer> sortedScores,
final int maxHits
) {
ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];

int shardId = compoundQueryTopDocs.scoreDocs[0].shardIndex;
-for (int j = 0; j < maxHits && !scoreQueue.isEmpty(); j++) {
-int docId = scoreQueue.poll();
+for (int j = 0; j < maxHits && j < sortedScores.size(); j++) {
+int docId = sortedScores.get(j);
finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
}
return finalScoreDocs;
}

-private Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
+@VisibleForTesting
+public Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
Map<Integer, float[]> normalizedScoresPerDoc = new HashMap<>();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs topDocs = topDocsPerSubQuery.get(j);
@@ -120,7 +127,8 @@ private Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs>
return normalizedScoresPerDoc;
}

-private Map<Integer, Float> combineScoresAndGetCombinedNormilizedScoresPerDocument(
+@VisibleForTesting
+protected Map<Integer, Float> combineScoresAndGetCombinedNormilizedScoresPerDocument(
final Map<Integer, float[]> normalizedScoresPerDocument,
final ScoreCombinationTechnique scoreCombinationTechnique
) {
@@ -129,20 +137,22 @@ private Map<Integer, Float> combineScoresAndGetCombinedNormilizedScoresPerDocume
.collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue())));
}

-private void updateQueryTopDocsWithCombinedScores(
+@VisibleForTesting
+protected void updateQueryTopDocsWithCombinedScores(
final CompoundTopDocs compoundQueryTopDocs,
final List<TopDocs> topDocsPerSubQuery,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
-final PriorityQueue<Integer> scoreQueue
+final List<Integer> sortedScores
) {
// - count max number of hits among sub-queries
int maxHits = getMaxHits(topDocsPerSubQuery);
// - update query search results with normalized scores
-compoundQueryTopDocs.scoreDocs = getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, scoreQueue, maxHits);
+compoundQueryTopDocs.scoreDocs = getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits);
compoundQueryTopDocs.totalHits = getTotalHits(topDocsPerSubQuery, maxHits);
}

-private int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
+@VisibleForTesting
+protected int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
int maxHits = 0;
for (TopDocs topDocs : topDocsPerSubQuery) {
int hits = topDocs.scoreDocs.length;
@@ -151,7 +161,8 @@ private int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
return maxHits;
}

-private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, int maxHits) {
+@VisibleForTesting
+protected TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, int maxHits) {
TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO;
if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) {
totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
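The PriorityQueue-to-List change above also removes an ordering subtlety: polling a queue mutates it, which is why the old code had to capture maxScore before updating the results, whereas a sorted list can be read repeatedly and the max combined score is simply the score of element 0. A self-contained sketch of the descending sort used in getPriorityQueueOfDocIds, with made-up doc ids and scores:

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class SortedDocIdsSketch {
    public static void main(String[] args) {
        // hypothetical combined scores keyed by doc id
        Map<Integer, Float> combinedScoresByDocId = Map.of(1, 0.2f, 2, 0.9f, 3, 0.5f);
        List<Integer> sortedDocIds = new ArrayList<>(combinedScoresByDocId.keySet());
        // sort doc ids in descending order of their combined scores, as in the diff above
        sortedDocIds.sort((a, b) -> Float.compare(combinedScoresByDocId.get(b), combinedScoresByDocId.get(a)));
        System.out.println(sortedDocIds); // [2, 3, 1]
        System.out.println(combinedScoresByDocId.get(sortedDocIds.get(0))); // max combined score: 0.9
    }
}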
NormalizationProcessorFactory.java
@@ -11,25 +11,23 @@
import java.util.Map;
import java.util.Objects;

+import lombok.AllArgsConstructor;

import org.apache.commons.lang3.EnumUtils;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
-import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
-import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

/**
* Factory for query results normalization processor for search pipeline. Instantiates processor based on user provided input.
*/
+@AllArgsConstructor
public class NormalizationProcessorFactory implements Processor.Factory<SearchPhaseResultsProcessor> {
-private final NormalizationProcessorWorkflow normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(
-new ScoreNormalizer(),
-new ScoreCombiner()
-);
+private final NormalizationProcessorWorkflow normalizationProcessorWorkflow;

@Override
public SearchPhaseResultsProcessor create(
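With Lombok's @AllArgsConstructor, the factory now receives its NormalizationProcessorWorkflow through a generated constructor instead of building its own copy, so the plugin can share one instance and tests can supply their own. A sketch of the resulting wiring, mirroring the NeuralSearch and test changes in this diff:

// the Lombok-generated constructor is equivalent to
// public NormalizationProcessorFactory(NormalizationProcessorWorkflow normalizationProcessorWorkflow)
NormalizationProcessorWorkflow workflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
NormalizationProcessorFactory factory = new NormalizationProcessorFactory(workflow);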
NormalizationProcessorFactoryTests.java
@@ -14,8 +14,11 @@

import org.opensearch.OpenSearchParseException;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
+import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
+import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
+import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.test.OpenSearchTestCase;
@@ -24,7 +27,9 @@ public class NormalizationProcessorFactoryTests extends OpenSearchTestCase {

@SneakyThrows
public void testNormalizationProcessor_whenNoParams_thenSuccessful() {
-NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory();
+NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(
+new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+);
final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>();
String tag = "tag";
String description = "description";
@@ -47,7 +52,9 @@

@SneakyThrows
public void testNormalizationProcessor_whenWithParams_thenSuccessful() {
-NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory();
+NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(
+new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+);
final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>();
String tag = "tag";
String description = "description";
@@ -71,7 +78,9 @@ }
}

public void testInputValidation_whenInvalidParameters_thenFail() {
-NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory();
+NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(
+new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+);
Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>();
String tag = "tag";
String description = "description";