diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
index f3206728d..367dee674 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
@@ -36,12 +36,15 @@ public class HybridQueryScorer extends Scorer {
 
     private final DocIdSetIterator approximation;
 
+    float[] subScores;
+
     Map<Query, Integer> queryToIndex;
 
     HybridQueryScorer(Weight weight, Scorer[] subScorers) throws IOException {
         super(weight);
         this.subScorers = subScorers;
         queryToIndex = new HashMap<>();
+        subScores = new float[subScorers.length];
         int idx = 0;
         int size = 0;
         for (Scorer scorer : subScorers) {
@@ -67,12 +70,14 @@ public class HybridQueryScorer extends Scorer {
     public float score() throws IOException {
         float scoreMax = 0;
         double otherScoreSum = 0;
+        subScores = new float[subScores.length];
         for (DisiWrapper w : subScorersPQ) {
             // check if this doc has match in the subQuery. If not, add score as 0.0 and continue
             if (w.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
                 continue;
             }
             float subScore = w.scorer.score();
+            subScores[queryToIndex.get(w.scorer.getWeight().getQuery())] = subScore;
             if (subScore >= scoreMax) {
                 otherScoreSum += scoreMax;
                 scoreMax = subScore;
@@ -121,4 +126,23 @@ public float getMaxScore(int upTo) throws IOException {
     public int docID() {
         return subScorersPQ.top().doc;
     }
+
+    /**
+     * Return array of scores per sub-query for the doc id at the current iterator position
+     * @return float array with one entry per sub-query, in sub-query order; entries stay 0.0 for sub-queries without a match
+     * @throws IOException if an underlying scorer fails to compute its score
+     */
+    public float[] hybridScores() throws IOException {
+        float[] scores = new float[subScores.length];
+        DisiWrapper topList = subScorersPQ.topList();
+        for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
+            // check if this doc has match in the subQuery. If not, add score as 0.0 and continue
+            if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
+                continue;
+            }
+            float subScore = disiWrapper.scorer.score();
+            scores[queryToIndex.get(disiWrapper.scorer.getWeight().getQuery())] = subScore;
+        }
+        return scores;
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
index fb1874fcd..35ab08aa2 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.neuralsearch.query;
 
+import static org.hamcrest.Matchers.containsString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -205,7 +206,8 @@ public void testDoToQuery_whenTooManySubqueries_thenFail() throws Exception {
         );
         contentParser.nextToken();
 
-        expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser));
+        ParsingException exception = expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser));
+        assertThat(exception.getMessage(), containsString("Number of sub-queries exceeds maximum supported"));
     }
 
     /**
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java
new file mode 100644
index 000000000..285588495
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java
@@ -0,0 +1,272 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.query;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
+import lombok.SneakyThrows;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.MatchNoDocsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.TestUtil;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+
+public class HybridQueryScorerTests extends LuceneTestCase {
+
+    @SneakyThrows
+    public void testWithRandomDocuments_whenOneSubScorer_thenReturnSuccessfully() {
+        final int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
+        final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
+        final int[] docs = new int[numDocs];
+        final Set<Integer> uniqueDocs = new HashSet<>();
+        while (uniqueDocs.size() < numDocs) {
+            uniqueDocs.add(random().nextInt(maxDocId));
+        }
+        int i = 0;
+        for (int doc : uniqueDocs) {
+            docs[i++] = doc;
+        }
+        Arrays.sort(docs);
+        final float[] scores = new float[numDocs];
+        for (int j = 0; j < numDocs; ++j) {
+            scores[j] = random().nextFloat();
+        }
+
+        Weight weight = mock(Weight.class);
+
+        HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
+            weight,
+            new Scorer[] { scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())) }
+        );
+        int doc = -1;
+        int numOfActualDocs = 0;
+        while (doc != NO_MORE_DOCS) {
+            int target = doc + 1;
+            doc = hybridQueryScorer.iterator().nextDoc();
+            int idx = Arrays.binarySearch(docs, target);
+            idx = (idx >= 0) ? idx : (-1 - idx);
+            if (idx == docs.length) {
+                assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc);
+            } else {
+                assertEquals(docs[idx], doc);
+                assertEquals(scores[idx], hybridQueryScorer.score(), 0f);
+                numOfActualDocs++;
+            }
+        }
+        assertEquals(numDocs, numOfActualDocs);
+    }
+
+    @SneakyThrows
+    public void testWithRandomDocuments_whenMultipleScorers_thenReturnSuccessfully() {
+        final int maxDocId1 = TestUtil.nextInt(random(), 2, 10_000);
+        final int numDocs1 = RandomizedTest.randomIntBetween(1, maxDocId1 / 2);
+        final int[] docs1 = new int[numDocs1];
+        final Set<Integer> uniqueDocs1 = new HashSet<>();
+        while (uniqueDocs1.size() < numDocs1) {
+            uniqueDocs1.add(random().nextInt(maxDocId1));
+        }
+        int i = 0;
+        for (int doc : uniqueDocs1) {
+            docs1[i++] = doc;
+        }
+        Arrays.sort(docs1);
+        final float[] scores1 = new float[numDocs1];
+        for (int j = 0; j < numDocs1; ++j) {
+            scores1[j] = random().nextFloat();
+        }
+
+        final int maxDocId2 = TestUtil.nextInt(random(), 2, 10_000);
+        final int numDocs2 = RandomizedTest.randomIntBetween(1, maxDocId2 / 2);
+        final int[] docs2 = new int[numDocs2];
+        final Set<Integer> uniqueDocs2 = new HashSet<>();
+        while (uniqueDocs2.size() < numDocs2) {
+            uniqueDocs2.add(random().nextInt(maxDocId2));
+        }
+        i = 0;
+        for (int doc : uniqueDocs2) {
+            docs2[i++] = doc;
+        }
+        Arrays.sort(docs2);
+        final float[] scores2 = new float[numDocs2];
+        for (int j = 0; j < numDocs2; ++j) {
+            scores2[j] = random().nextFloat();
+        }
+
+        Set<Integer> uniqueDocsAll = new HashSet<>();
+        uniqueDocsAll.addAll(uniqueDocs1);
+        uniqueDocsAll.addAll(uniqueDocs2);
+
+        HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
+            mock(Weight.class),
+            new Scorer[] {
+                scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())),
+                scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery())) }
+        );
+        int doc = -1;
+        int numOfActualDocs = 0;
+        while (doc != NO_MORE_DOCS) {
+            doc = hybridQueryScorer.iterator().nextDoc();
+            if (doc == DocIdSetIterator.NO_MORE_DOCS) {
+                continue;
+            }
+            float[] actualTotalScores = hybridQueryScorer.hybridScores();
+            float actualTotalScore = 0.0f;
+            for (float score : actualTotalScores) {
+                actualTotalScore += score;
+            }
+            float expectedScore = 0.0f;
+            if (uniqueDocs1.contains(doc)) {
+                int idx = Arrays.binarySearch(docs1, doc);
+                expectedScore += scores1[idx];
+            }
+            if (uniqueDocs2.contains(doc)) {
+                int idx = Arrays.binarySearch(docs2, doc);
+                expectedScore += scores2[idx];
+            }
+            assertEquals(expectedScore, actualTotalScore, 0.001f);
+            numOfActualDocs++;
+        }
+        assertEquals(uniqueDocsAll.size(), numOfActualDocs);
+    }
+
+    @SneakyThrows
+    public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenReturnSuccessfully() {
+        final int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
+        final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
+        final int[] docs = new int[numDocs];
+        final Set<Integer> uniqueDocs = new HashSet<>();
+        while (uniqueDocs.size() < numDocs) {
+            uniqueDocs.add(random().nextInt(maxDocId));
+        }
+        int i = 0;
+        for (int doc : uniqueDocs) {
+            docs[i++] = doc;
+        }
+        Arrays.sort(docs);
+        final float[] scores = new float[numDocs];
+        for (int j = 0; j < numDocs; ++j) {
+            scores[j] = random().nextFloat();
+        }
+
+        Weight weight = mock(Weight.class);
+
+        HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
+            weight,
+            new Scorer[] { null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null }
+        );
+        int doc = -1;
+        int numOfActualDocs = 0;
+        while (doc != NO_MORE_DOCS) {
+            int target = doc + 1;
+            doc = hybridQueryScorer.iterator().nextDoc();
+            int idx = Arrays.binarySearch(docs, target);
+            idx = (idx >= 0) ? idx : (-1 - idx);
+            if (idx == docs.length) {
+                assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc);
+            } else {
+                assertEquals(docs[idx], doc);
+                assertEquals(scores[idx], hybridQueryScorer.score(), 0f);
+                numOfActualDocs++;
+            }
+        }
+        assertEquals(numDocs, numOfActualDocs);
+    }
+
+    private static Weight fakeWeight(Query query) {
+        return new Weight(query) {
+
+            @Override
+            public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+                return null;
+            }
+
+            @Override
+            public Scorer scorer(LeafReaderContext context) throws IOException {
+                return null;
+            }
+
+            @Override
+            public boolean isCacheable(LeafReaderContext ctx) {
+                return false;
+            }
+        };
+    }
+
+    private static DocIdSetIterator iterator(final int... docs) {
+        return new DocIdSetIterator() {
+
+            int i = -1;
+
+            @Override
+            public int nextDoc() throws IOException {
+                if (i + 1 == docs.length) {
+                    return NO_MORE_DOCS;
+                } else {
+                    return docs[++i];
+                }
+            }
+
+            @Override
+            public int docID() {
+                return i < 0 ? -1 : i == docs.length ? NO_MORE_DOCS : docs[i];
+            }
+
+            @Override
+            public long cost() {
+                return docs.length;
+            }
+
+            @Override
+            public int advance(int target) throws IOException {
+                return slowAdvance(target);
+            }
+        };
+    }
+
+    private static Scorer scorer(final int[] docs, final float[] scores, Weight weight) {
+        final DocIdSetIterator iterator = iterator(docs);
+        return new Scorer(weight) {
+
+            int lastScoredDoc = -1;
+
+            public DocIdSetIterator iterator() {
+                return iterator;
+            }
+
+            @Override
+            public int docID() {
+                return iterator.docID();
+            }
+
+            @Override
+            public float score() throws IOException {
+                assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID());
+                lastScoredDoc = docID();
+                final int idx = Arrays.binarySearch(docs, docID());
+                return scores[idx];
+            }
+
+            @Override
+            public float getMaxScore(int upTo) throws IOException {
+                return Float.MAX_VALUE;
+            }
+        };
+    }
+}
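
Reviewer note, not part of the diff: a minimal sketch of how a downstream consumer (for example, a score normalization step) might drive the new hybridScores() API. The HybridScoresReader class and its collect() method are hypothetical names introduced only for illustration; the HybridQueryScorer calls match the code added above.

    import java.io.IOException;

    import org.apache.lucene.search.DocIdSetIterator;
    import org.opensearch.neuralsearch.query.HybridQueryScorer;

    // Hypothetical consumer that walks all matching docs and reads the
    // per-sub-query score array instead of the single combined score().
    public class HybridScoresReader {

        // For each matching doc, hybridScores() returns one float per sub-query,
        // in sub-query order; slots for non-matching sub-queries stay 0.0f.
        public static void collect(HybridQueryScorer scorer) throws IOException {
            DocIdSetIterator iterator = scorer.iterator();
            for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
                float[] scoresPerSubQuery = scorer.hybridScores();
                // a real caller would hand (doc, scoresPerSubQuery) to a
                // normalization/combination step here
            }
        }
    }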