Skip to content

Commit

Permalink
Implement scorer.score as sum of all sub-scores
Browse files Browse the repository at this point in the history
The standard scorer.score implementation now returns the sum of all sub-scores as
a single score per doc id.
Fixed unit tests

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed May 22, 2023
1 parent 8624b19 commit 6bd9f8e
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@
/**
 * Implementation of the Query interface for type "hybrid". It allows execution of multiple sub-queries and collects individual
 * scores for each sub-query.
*
* @opensearch.internal
*/
public class HybridQuery extends Query implements Iterable<Query> {
public final class HybridQuery extends Query implements Iterable<Query> {

private final List<Query> subQueries;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@
/**
 * Class that abstracts creation of a Query of type "hybrid". A hybrid query allows execution of multiple sub-queries and
 * collects a score for each of those sub-queries.
*
* @opensearch.internal
*/
@Log4j2
@Getter
@Setter
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
public class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBuilder> {
public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBuilder> {
public static final String NAME = "hybrid";

private static final ParseField QUERIES_FIELD = new ParseField("queries");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
 * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
 * order of doc id, this class fills up an array of scores per sub-query for each doc id. The order in the array of scores
 * corresponds to the order of sub-queries in the input Hybrid query.
*
* @opensearch.internal
*/
public class HybridQueryScorer extends Scorer {
public final class HybridQueryScorer extends Scorer {

// score for each of sub-query in this hybrid query
@Getter
Expand Down Expand Up @@ -67,25 +65,23 @@ public class HybridQueryScorer extends Scorer {
this.approximation = new DisjunctionDISIApproximation(this.subScorersPQ);
}

/**
* Returns the score of the current document matching the query. Score is a sum of all scores from sub-query scorers.
* @return combined total score of all sub-scores
* @throws IOException
*/
@Override
public float score() throws IOException {
float scoreMax = 0;
double otherScoreSum = 0;
subScores = new float[subScores.length];
for (DisiWrapper w : subScorersPQ) {
DisiWrapper topList = subScorersPQ.topList();
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
if (w.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
float subScore = w.scorer.score();
subScores[queryToIndex.get(w.scorer.getWeight().getQuery())] = subScore;
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
} else {
otherScoreSum += subScore;
}
totalScore += disiWrapper.scorer.score();
}
return (float) (scoreMax + otherScoreSum);
return totalScore;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
/**
 * Calculates query weights and builds query scorers for hybrid query.
 */
class HybridQueryWeight extends Weight {
public final class HybridQueryWeight extends Weight {

private final HybridQuery queries;
// The Weights for our subqueries, in 1-1 correspondence
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

import lombok.SneakyThrows;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
Expand All @@ -32,95 +35,44 @@ public class HybridQueryScorerTests extends LuceneTestCase {

@SneakyThrows
public void testWithRandomDocuments_whenOneSubScorer_thenReturnSuccessfully() {
final int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
final Set<Integer> uniqueDocs = new HashSet<>();
while (uniqueDocs.size() < numDocs) {
uniqueDocs.add(random().nextInt(maxDocId));
}
int i = 0;
for (int doc : uniqueDocs) {
docs[i++] = doc;
}
Arrays.sort(docs);
final float[] scores = new float[numDocs];
for (int j = 0; j < numDocs; ++j) {
scores[j] = random().nextFloat();
}
int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
int[] docs = docsAndScores.getLeft();
float[] scores = docsAndScores.getRight();

Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
weight,
new Scorer[] { scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())) }
new Scorer[] { scorer(docsAndScores.getLeft(), docsAndScores.getRight(), fakeWeight(new MatchAllDocsQuery())) }
);
int doc = -1;
int numOfActualDocs = 0;
while (doc != NO_MORE_DOCS) {
int target = doc + 1;
doc = hybridQueryScorer.iterator().nextDoc();
int idx = Arrays.binarySearch(docs, target);
idx = (idx >= 0) ? idx : (-1 - idx);
if (idx == docs.length) {
assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc);
} else {
assertEquals(docs[idx], doc);
assertEquals(scores[idx], hybridQueryScorer.score(), 0f);
numOfActualDocs++;
}
}
assertEquals(numDocs, numOfActualDocs);

testWithQuery(docs, scores, hybridQueryScorer);
}

@SneakyThrows
public void testWithRandomDocuments_whenMultipleScorers_thenReturnSuccessfully() {
final int maxDocId1 = TestUtil.nextInt(random(), 2, 10_000);
final int numDocs1 = RandomizedTest.randomIntBetween(1, maxDocId1 / 2);
final int[] docs1 = new int[numDocs1];
final Set<Integer> uniqueDocs1 = new HashSet<>();
while (uniqueDocs1.size() < numDocs1) {
uniqueDocs1.add(random().nextInt(maxDocId1));
}
int i = 0;
for (int doc : uniqueDocs1) {
docs1[i++] = doc;
}
Arrays.sort(docs1);
final float[] scores1 = new float[numDocs1];
for (int j = 0; j < numDocs1; ++j) {
scores1[j] = random().nextFloat();
}
public void testWithRandomDocumentsAndHybridScores_whenMultipleScorers_thenReturnSuccessfully() {
int maxDocId1 = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores1 = generateDocuments(maxDocId1);
int[] docs1 = docsAndScores1.getLeft();
float[] scores1 = docsAndScores1.getRight();
int maxDocId2 = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores2 = generateDocuments(maxDocId2);
int[] docs2 = docsAndScores2.getLeft();
float[] scores2 = docsAndScores2.getRight();

final int maxDocId2 = TestUtil.nextInt(random(), 2, 10_000);
final int numDocs2 = RandomizedTest.randomIntBetween(1, maxDocId2 / 2);
final int[] docs2 = new int[numDocs2];
final Set<Integer> uniqueDocs2 = new HashSet<>();
while (uniqueDocs2.size() < numDocs2) {
uniqueDocs2.add(random().nextInt(maxDocId2));
}
i = 0;
for (int doc : uniqueDocs2) {
docs2[i++] = doc;
}
Arrays.sort(docs2);
final float[] scores2 = new float[numDocs2];
for (int j = 0; j < numDocs2; ++j) {
scores2[j] = random().nextFloat();
}

Set<Integer> uniqueDocsAll = new HashSet<>();
uniqueDocsAll.addAll(uniqueDocs1);
uniqueDocsAll.addAll(uniqueDocs2);
Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
mock(Weight.class),
weight,
new Scorer[] {
scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())),
scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery())) }
);
int doc = -1;
int numOfActualDocs = 0;
Set<Integer> uniqueDocs1 = Arrays.stream(docs1).boxed().collect(Collectors.toSet());
Set<Integer> uniqueDocs2 = Arrays.stream(docs2).boxed().collect(Collectors.toSet());
while (doc != NO_MORE_DOCS) {
doc = hybridQueryScorer.iterator().nextDoc();
if (doc == DocIdSetIterator.NO_MORE_DOCS) {
Expand All @@ -143,12 +95,84 @@ public void testWithRandomDocuments_whenMultipleScorers_thenReturnSuccessfully()
assertEquals(expectedScore, actualTotalScore, 0.001f);
numOfActualDocs++;
}
assertEquals(uniqueDocsAll.size(), numOfActualDocs);

int totalUniqueCount = uniqueDocs1.size();
for (int n : uniqueDocs2) {
if (!uniqueDocs1.contains(n)) {
totalUniqueCount++;
}
}
assertEquals(totalUniqueCount, numOfActualDocs);
}

@SneakyThrows
public void testWithRandomDocumentsAndCombinedScore_whenMultipleScorers_thenReturnSuccessfully() {
    // two independent random doc/score sets, one per sub-scorer
    int maxDocId1 = TestUtil.nextInt(random(), 10, 10_000);
    Pair<int[], float[]> docsAndScores1 = generateDocuments(maxDocId1);
    int[] docs1 = docsAndScores1.getLeft();
    float[] scores1 = docsAndScores1.getRight();
    int maxDocId2 = TestUtil.nextInt(random(), 10, 10_000);
    Pair<int[], float[]> docsAndScores2 = generateDocuments(maxDocId2);
    int[] docs2 = docsAndScores2.getLeft();
    float[] scores2 = docsAndScores2.getRight();

    Weight weight = mock(Weight.class);

    HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
        weight,
        new Scorer[] {
            scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())),
            scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery())) }
    );
    int doc = -1;
    int numOfActualDocs = 0;
    Set<Integer> uniqueDocs1 = Arrays.stream(docs1).boxed().collect(Collectors.toSet());
    Set<Integer> uniqueDocs2 = Arrays.stream(docs2).boxed().collect(Collectors.toSet());
    while (doc != NO_MORE_DOCS) {
        doc = hybridQueryScorer.iterator().nextDoc();
        if (doc == DocIdSetIterator.NO_MORE_DOCS) {
            continue;
        }
        // expected combined score is the sum of sub-scores from every sub-query matching this doc
        float expectedScore = 0.0f;
        if (uniqueDocs1.contains(doc)) {
            int idx = Arrays.binarySearch(docs1, doc);
            expectedScore += scores1[idx];
        }
        if (uniqueDocs2.contains(doc)) {
            int idx = Arrays.binarySearch(docs2, doc);
            expectedScore += scores2[idx];
        }
        assertEquals(expectedScore, hybridQueryScorer.score(), 0.001f);
        numOfActualDocs++;
    }

    // docs emitted by the hybrid scorer must equal the union of both doc-id sets
    Set<Integer> allUniqueDocs = new HashSet<>(uniqueDocs1);
    allUniqueDocs.addAll(uniqueDocs2);
    assertEquals(allUniqueDocs.size(), numOfActualDocs);
}

@SneakyThrows
public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenReturnSuccessfully() {
final int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
int[] docs = docsAndScores.getLeft();
float[] scores = docsAndScores.getRight();

Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
weight,
new Scorer[] { null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null }
);

testWithQuery(docs, scores, hybridQueryScorer);
}

private Pair<int[], float[]> generateDocuments(int maxDocId) {
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
final Set<Integer> uniqueDocs = new HashSet<>();
Expand All @@ -164,13 +188,10 @@ public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenR
for (int j = 0; j < numDocs; ++j) {
scores[j] = random().nextFloat();
}
return new ImmutablePair<>(docs, scores);
}

Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
weight,
new Scorer[] { null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null }
);
private void testWithQuery(int[] docs, float[] scores, HybridQueryScorer hybridQueryScorer) throws IOException {
int doc = -1;
int numOfActualDocs = 0;
while (doc != NO_MORE_DOCS) {
Expand All @@ -186,7 +207,7 @@ public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenR
numOfActualDocs++;
}
}
assertEquals(numDocs, numOfActualDocs);
assertEquals(docs.length, numOfActualDocs);
}

private static Weight fakeWeight(Query query) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,30 @@ public void testRewrite() throws Exception {
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

IndexReader reader = mock(IndexReader.class);
Directory directory = newDirectory();
final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();

w.addDocument(getDocument(RandomizedTest.randomInt(), RandomizedTest.randomAsciiAlphanumOfLength(8), ft));
w.addDocument(getDocument(RandomizedTest.randomInt(), RandomizedTest.randomAsciiAlphanumOfLength(8), ft));
w.addDocument(getDocument(RandomizedTest.randomInt(), RandomizedTest.randomAsciiAlphanumOfLength(8), ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
HybridQuery query = new HybridQuery(
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
Query rewritten = query.rewrite(reader);
QueryUtils.checkUnequal(query, rewritten);
Query rewritten2 = rewritten.rewrite(reader);
assertSame(rewritten, rewritten2);

w.close();
reader.close();
directory.close();
}

@SneakyThrows
Expand Down

0 comments on commit 6bd9f8e

Please sign in to comment.