Adding hybrid score collection and unit test
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed May 19, 2023
1 parent afd4c7c commit 8624b19
Showing 3 changed files with 299 additions and 1 deletion.
@@ -36,12 +36,15 @@ public class HybridQueryScorer extends Scorer {

private final DocIdSetIterator approximation;

float[] subScores;

Map<Query, Integer> queryToIndex;

HybridQueryScorer(Weight weight, Scorer[] subScorers) throws IOException {
super(weight);
this.subScorers = subScorers;
queryToIndex = new HashMap<>();
subScores = new float[subScorers.length];
int idx = 0;
int size = 0;
for (Scorer scorer : subScorers) {
@@ -67,12 +70,14 @@ public class HybridQueryScorer extends Scorer {
public float score() throws IOException {
float scoreMax = 0;
double otherScoreSum = 0;
subScores = new float[subScores.length];
for (DisiWrapper w : subScorersPQ) {
// check if this doc has a match in the subQuery. If not, keep the score at 0.0 and continue
if (w.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
float subScore = w.scorer.score();
subScores[queryToIndex.get(w.scorer.getWeight().getQuery())] = subScore;
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
@@ -121,4 +126,23 @@ public float getMaxScore(int upTo) throws IOException {
public int docID() {
return subScorersPQ.top().doc;
}

/**
* Returns an array of scores, one per sub-query, for the doc id at the current iterator position
* @return array of sub-query scores, with 0.0 for sub-queries that do not match the current doc
* @throws IOException
*/
public float[] hybridScores() throws IOException {
float[] scores = new float[subScores.length];
DisiWrapper topList = subScorersPQ.topList();
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has a match in the subQuery. If not, keep the score at 0.0 and continue
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
float subScore = disiWrapper.scorer.score();
scores[queryToIndex.get(disiWrapper.scorer.getWeight().getQuery())] = subScore;
}
return scores;
}
}
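For context, the new hybridScores() method is meant to be read alongside the scorer's doc id iterator: at each iterator position it returns one score per sub-query. Below is a minimal usage sketch; the HybridScoresUsageSketch class name and the printing logic are illustrative assumptions, not code from this commit.

// A minimal sketch (not part of this commit) of how hybridScores() could be consumed,
// e.g. by a collector that needs per-sub-query scores for later normalization.
import java.io.IOException;
import java.util.Arrays;

import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.neuralsearch.query.HybridQueryScorer;

public class HybridScoresUsageSketch {

    // Iterate all matching docs and print the per-sub-query scores for each one.
    public static void collect(HybridQueryScorer scorer) throws IOException {
        DocIdSetIterator iterator = scorer.iterator();
        int doc;
        while ((doc = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
            // One entry per sub-query; entries stay 0.0f for sub-queries that do not match this doc.
            float[] subQueryScores = scorer.hybridScores();
            System.out.println("doc=" + doc + " scores=" + Arrays.toString(subQueryScores));
        }
    }
}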
@@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.query;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -205,7 +206,8 @@ public void testDoToQuery_whenTooManySubqueries_thenFail() throws Exception {
);
contentParser.nextToken();

expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser));
ParsingException exception = expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser));
assertThat(exception.getMessage(), containsString("Number of sub-queries exceeds maximum supported"));
}

/**
@@ -0,0 +1,272 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.query;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.Mockito.mock;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import lombok.SneakyThrows;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class HybridQueryScorerTests extends LuceneTestCase {

@SneakyThrows
public void testWithRandomDocuments_whenOneSubScorer_thenReturnSuccessfully() {
final int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
final Set<Integer> uniqueDocs = new HashSet<>();
while (uniqueDocs.size() < numDocs) {
uniqueDocs.add(random().nextInt(maxDocId));
}
int i = 0;
for (int doc : uniqueDocs) {
docs[i++] = doc;
}
Arrays.sort(docs);
final float[] scores = new float[numDocs];
for (int j = 0; j < numDocs; ++j) {
scores[j] = random().nextFloat();
}

Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
weight,
new Scorer[] { scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())) }
);
int doc = -1;
int numOfActualDocs = 0;
while (doc != NO_MORE_DOCS) {
int target = doc + 1;
doc = hybridQueryScorer.iterator().nextDoc();
int idx = Arrays.binarySearch(docs, target);
idx = (idx >= 0) ? idx : (-1 - idx);
if (idx == docs.length) {
assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc);
} else {
assertEquals(docs[idx], doc);
assertEquals(scores[idx], hybridQueryScorer.score(), 0f);
numOfActualDocs++;
}
}
assertEquals(numDocs, numOfActualDocs);
}

@SneakyThrows
public void testWithRandomDocuments_whenMultipleScorers_thenReturnSuccessfully() {
final int maxDocId1 = TestUtil.nextInt(random(), 2, 10_000);
final int numDocs1 = RandomizedTest.randomIntBetween(1, maxDocId1 / 2);
final int[] docs1 = new int[numDocs1];
final Set<Integer> uniqueDocs1 = new HashSet<>();
while (uniqueDocs1.size() < numDocs1) {
uniqueDocs1.add(random().nextInt(maxDocId1));
}
int i = 0;
for (int doc : uniqueDocs1) {
docs1[i++] = doc;
}
Arrays.sort(docs1);
final float[] scores1 = new float[numDocs1];
for (int j = 0; j < numDocs1; ++j) {
scores1[j] = random().nextFloat();
}

final int maxDocId2 = TestUtil.nextInt(random(), 2, 10_000);
final int numDocs2 = RandomizedTest.randomIntBetween(1, maxDocId2 / 2);
final int[] docs2 = new int[numDocs2];
final Set<Integer> uniqueDocs2 = new HashSet<>();
while (uniqueDocs2.size() < numDocs2) {
uniqueDocs2.add(random().nextInt(maxDocId2));
}
i = 0;
for (int doc : uniqueDocs2) {
docs2[i++] = doc;
}
Arrays.sort(docs2);
final float[] scores2 = new float[numDocs2];
for (int j = 0; j < numDocs2; ++j) {
scores2[j] = random().nextFloat();
}

Set<Integer> uniqueDocsAll = new HashSet<>();
uniqueDocsAll.addAll(uniqueDocs1);
uniqueDocsAll.addAll(uniqueDocs2);

HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
mock(Weight.class),
new Scorer[] {
scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())),
scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery())) }
);
int doc = -1;
int numOfActualDocs = 0;
while (doc != NO_MORE_DOCS) {
doc = hybridQueryScorer.iterator().nextDoc();
if (doc == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
float[] actualTotalScores = hybridQueryScorer.hybridScores();
float actualTotalScore = 0.0f;
for (float score : actualTotalScores) {
actualTotalScore += score;
}
float expectedScore = 0.0f;
if (uniqueDocs1.contains(doc)) {
int idx = Arrays.binarySearch(docs1, doc);
expectedScore += scores1[idx];
}
if (uniqueDocs2.contains(doc)) {
int idx = Arrays.binarySearch(docs2, doc);
expectedScore += scores2[idx];
}
assertEquals(expectedScore, actualTotalScore, 0.001f);
numOfActualDocs++;
}
assertEquals(uniqueDocsAll.size(), numOfActualDocs);
}

@SneakyThrows
public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenReturnSuccessfully() {
final int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
final Set<Integer> uniqueDocs = new HashSet<>();
while (uniqueDocs.size() < numDocs) {
uniqueDocs.add(random().nextInt(maxDocId));
}
int i = 0;
for (int doc : uniqueDocs) {
docs[i++] = doc;
}
Arrays.sort(docs);
final float[] scores = new float[numDocs];
for (int j = 0; j < numDocs; ++j) {
scores[j] = random().nextFloat();
}

Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
weight,
new Scorer[] { null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null }
);
int doc = -1;
int numOfActualDocs = 0;
while (doc != NO_MORE_DOCS) {
int target = doc + 1;
doc = hybridQueryScorer.iterator().nextDoc();
int idx = Arrays.binarySearch(docs, target);
idx = (idx >= 0) ? idx : (-1 - idx);
if (idx == docs.length) {
assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc);
} else {
assertEquals(docs[idx], doc);
assertEquals(scores[idx], hybridQueryScorer.score(), 0f);
numOfActualDocs++;
}
}
assertEquals(numDocs, numOfActualDocs);
}

private static Weight fakeWeight(Query query) {
return new Weight(query) {

@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
return null;
}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
return null;
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
};
}

private static DocIdSetIterator iterator(final int... docs) {
return new DocIdSetIterator() {

int i = -1;

@Override
public int nextDoc() throws IOException {
if (i + 1 == docs.length) {
return NO_MORE_DOCS;
} else {
return docs[++i];
}
}

@Override
public int docID() {
return i < 0 ? -1 : i == docs.length ? NO_MORE_DOCS : docs[i];
}

@Override
public long cost() {
return docs.length;
}

@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
};
}

private static Scorer scorer(final int[] docs, final float[] scores, Weight weight) {
final DocIdSetIterator iterator = iterator(docs);
return new Scorer(weight) {

int lastScoredDoc = -1;

public DocIdSetIterator iterator() {
return iterator;
}

@Override
public int docID() {
return iterator.docID();
}

@Override
public float score() throws IOException {
assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID());
lastScoredDoc = docID();
final int idx = Arrays.binarySearch(docs, docID());
return scores[idx];
}

@Override
public float getMaxScore(int upTo) throws IOException {
return Float.MAX_VALUE;
}
};
}
}
