From 3c7f275e4d12e8dade324645ad2831d076c752bf Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 25 Nov 2024 16:09:44 -0800 Subject: [PATCH] Combining all changes in one PR after bad rebase (#998) Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/query/HybridQueryScorer.java | 13 +- .../collector/HybridTopScoreDocCollector.java | 9 + .../query/HybridQueryBuilderTests.java | 64 +++ .../query/HybridQueryScorerTests.java | 390 +++++++++++++++++- 5 files changed, 471 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..c127ef7d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index eb410aa23..48c69b618 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.query; +import com.google.common.annotations.VisibleForTesting; import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.DisiPriorityQueue; @@ -30,7 +31,7 @@ * corresponds to order of sub-queries in an input Hybrid query. */ @Log4j2 -public final class HybridQueryScorer extends Scorer { +public class HybridQueryScorer extends Scorer { // score for each of sub-query in this hybrid query @Getter @@ -100,7 +101,8 @@ public float score() throws IOException { return score(getSubMatches()); } - private float score(DisiWrapper topList) throws IOException { + @VisibleForTesting + float score(DisiWrapper topList) throws IOException { float totalScore = 0.0f; for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue @@ -187,7 +189,12 @@ public int docID() { */ public float[] hybridScores() throws IOException { float[] scores = new float[numSubqueries]; - DisiWrapper topList = subScorersPQ.topList(); + // retrieves sub-matches using DisjunctionDisiScorer's two-phase iteration process. + // while the two-phase iterator can efficiently skip blocks of document IDs during matching, + // the DisiWrapper (obtained from subScorersPQ.topList()) ensures sequential document ID iteration. + // this is necessary for maintaining correct scoring order. + DisiWrapper topList = getSubMatches(); + for (HybridDisiWrapper disiWrapper = (HybridDisiWrapper) topList; disiWrapper != null; disiWrapper = (HybridDisiWrapper) disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java index 4e72b55bf..7a3585cf9 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java @@ -108,12 +108,21 @@ public void collect(int doc) throws IOException { } // Increment total hit count which represents unique doc found on the shard totalHits++; + hitsThresholdChecker.incrementHitCount(); for (int i = 0; i < subScoresByQuery.length; i++) { float score = subScoresByQuery[i]; // if score is 0.0 there is no hits for that sub-query if (score == 0) { continue; } + if (hitsThresholdChecker.isThresholdReached() && totalHitsRelation == TotalHits.Relation.EQUAL_TO) { + log.info( + "hit count threshold reached: total hits={}, threshold={}, action=updating_results", + totalHits, + hitsThresholdChecker.getTotalHitsThreshold() + ); + totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + } collectedHitsPerSubQuery[i]++; PriorityQueue pq = compoundScores[i]; ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 2a6fa49a3..a26dd8263 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -50,6 +50,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -756,6 +757,69 @@ public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { assertNotNull(hybridQueryBuilder); } + @SneakyThrows + public void testBuild_whenValidParameters_thenCreateQuery() { + String queryText = "test query"; + String modelId = "test_model"; + String fieldName = "rank_features"; + + // Create mock context + QueryShardContext context = mock(QueryShardContext.class); + MappedFieldType fieldType = mock(MappedFieldType.class); + when(context.fieldMapper(fieldName)).thenReturn(fieldType); + when(fieldType.typeName()).thenReturn("rank_features"); + + // Create HybridQueryBuilder instance (no spy since it's final) + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName(fieldName) + .queryText(queryText) + .modelId(modelId) + .queryTokensSupplier(() -> Map.of("token1", 1.0f, "token2", 0.5f)); + HybridQueryBuilder builder = new HybridQueryBuilder().add(neuralSparseQueryBuilder); + + // Build query + Query query = builder.toQuery(context); + + // Verify + assertNotNull("Query should not be null", query); + assertTrue("Should be HybridQuery", query instanceof HybridQuery); + } + + @SneakyThrows + public void testDoEquals_whenSameParameters_thenEqual() { + // Create neural queries + NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder().queryText("test").modelId("test_model"); + + NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder().queryText("test").modelId("test_model"); + + // Create neural sparse queries with queryTokensSupplier + NeuralSparseQueryBuilder neuralSparseQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName("test_field") + .queryText("test") + .modelId("test_model") + .queryTokensSupplier(() -> Map.of("token1", 1.0f)); + + NeuralSparseQueryBuilder neuralSparseQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName("test_field") + .queryText("test") + .modelId("test_model") + .queryTokensSupplier(() -> Map.of("token1", 1.0f)); + + // Create builders + HybridQueryBuilder builder1 = new HybridQueryBuilder().add(neuralQueryBuilder1).add(neuralSparseQueryBuilder1); + + HybridQueryBuilder builder2 = new HybridQueryBuilder().add(neuralQueryBuilder2).add(neuralSparseQueryBuilder2); + + // Verify + assertTrue("Builders should be equal", builder1.equals(builder2)); + assertEquals("Hash codes should match", builder1.hashCode(), builder2.hashCode()); + } + + public void testValidate_whenInvalidParameters_thenThrowException() { + // Test null query builder + HybridQueryBuilder builderWithNull = new HybridQueryBuilder(); + IllegalArgumentException nullException = assertThrows(IllegalArgumentException.class, () -> builderWithNull.add(null)); + assertEquals("inner hybrid query clause cannot be null", nullException.getMessage()); + } + public void testVisit() { HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(new NeuralQueryBuilder()).add(new NeuralSparseQueryBuilder()); List visitedQueries = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java index e7325055e..5bf21c553 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -7,19 +7,28 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; @@ -28,6 +37,7 @@ import com.carrotsearch.randomizedtesting.RandomizedTest; import lombok.SneakyThrows; +import org.opensearch.neuralsearch.search.HybridDisiWrapper; public class HybridQueryScorerTests extends OpenSearchQueryTestCase { @@ -91,7 +101,7 @@ public void testWithRandomDocumentsAndHybridScores_whenMultipleScorers_thenRetur int idx = Arrays.binarySearch(docs2, doc); expectedScore += scores2[idx]; } - assertEquals(expectedScore, actualTotalScore, 0.001f); + assertEquals(expectedScore, actualTotalScore, DELTA_FOR_SCORE_ASSERTION); numOfActualDocs++; } @@ -146,7 +156,7 @@ public void testWithRandomDocumentsAndCombinedScore_whenMultipleScorers_thenRetu for (float score : hybridQueryScorer.hybridScores()) { hybridScore += score; } - assertEquals(expectedScore, hybridScore, 0.001f); + assertEquals(expectedScore, hybridScore, DELTA_FOR_SCORE_ASSERTION); numOfActualDocs++; } @@ -269,12 +279,386 @@ public void testApproximationIterator_whenSubScorerSupportsApproximation_thenSuc assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); } else { assertEquals(docs[idx], doc); - assertEquals(scores1[idx] + scores2[idx], queryScorer.score(), 0.001f); + assertEquals(scores1[idx] + scores2[idx], queryScorer.score(), DELTA_FOR_SCORE_ASSERTION); } idx++; } } + @SneakyThrows + public void testScore_whenMultipleSubScorers_thenSumScores() { + // Create mock scorers with iterators + Scorer scorer1 = mock(Scorer.class); + DocIdSetIterator iterator1 = mock(DocIdSetIterator.class); + when(scorer1.iterator()).thenReturn(iterator1); + when(scorer1.docID()).thenReturn(1); + when(scorer1.score()).thenReturn(0.5f); + + Scorer scorer2 = mock(Scorer.class); + DocIdSetIterator iterator2 = mock(DocIdSetIterator.class); + when(scorer2.iterator()).thenReturn(iterator2); + when(scorer2.docID()).thenReturn(1); + when(scorer2.score()).thenReturn(0.3f); + + // Create DisiWrapper list + DisiWrapper wrapper1 = new DisiWrapper(scorer1); + wrapper1.next = new DisiWrapper(scorer2); + + Weight weight = mock(Weight.class); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Arrays.asList(scorer1, scorer2)); + float score = hybridScorer.score(wrapper1); + + assertEquals("Combined score should be sum of individual scores", 0.8f, score, DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testScore_whenNoMoreDocs_thenReturnZero() { + // Create mock scorer + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(scorer.iterator()).thenReturn(iterator); + + // Setup iterator behavior + when(iterator.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(iterator.cost()).thenReturn(1L); + + // Create TwoPhaseIterator if needed + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + when(approximation.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(approximation.cost()).thenReturn(1L); + when(twoPhase.approximation()).thenReturn(approximation); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + + // Create wrapper + DisiWrapper wrapper = new DisiWrapper(scorer); + + // Create weight mock + Weight weight = mock(Weight.class); + + // Create HybridQueryScorer + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer)); + + // Test score method + float score = hybridScorer.score(wrapper); + + // Verify + assertEquals("Score should be 0.0 for NO_MORE_DOCS", 0.0f, score, DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testGetSubMatches_whenNoScorers_thenReturnNull() { + Weight weight = mock(Weight.class); + + // Create a scorer with a two-phase iterator that doesn't match + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + when(twoPhase.matches()).thenReturn(false); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + when(scorer.iterator()).thenReturn(iterator); + when(scorer.docID()).thenReturn(0); // Set a valid docID + + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer), ScoreMode.TOP_SCORES); + + DisiWrapper result = hybridScorer.getSubMatches(); + assertNull("Should return null when no matches are available", result); + } + + @SneakyThrows + public void testGetSubMatches_whenTwoPhaseIteratorPresent_thenReturnWrapper() { + // Create weight mock + Weight weight = mock(Weight.class); + + // Create scorer with iterator + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + + // Setup iterator behavior with AtomicInteger for state tracking + AtomicInteger currentDoc = new AtomicInteger(-1); + + when(iterator.docID()).thenAnswer(inv -> currentDoc.get()); + when(iterator.cost()).thenReturn(1L); + when(iterator.nextDoc()).thenAnswer(inv -> { + if (currentDoc.get() == -1) { + currentDoc.set(0); + return 0; + } + return DocIdSetIterator.NO_MORE_DOCS; + }); + + when(scorer.iterator()).thenReturn(iterator); + when(scorer.docID()).thenAnswer(inv -> currentDoc.get()); + + // Create and setup TwoPhaseIterator + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + + // Setup approximation behavior + when(approximation.docID()).thenAnswer(inv -> currentDoc.get()); + when(approximation.cost()).thenReturn(1L); + when(approximation.nextDoc()).thenAnswer(inv -> { + if (currentDoc.get() == -1) { + currentDoc.set(0); + return 0; + } + return DocIdSetIterator.NO_MORE_DOCS; + }); + + when(twoPhase.approximation()).thenReturn(approximation); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + when(twoPhase.matches()).thenReturn(true); + + // Create HybridQueryScorer + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer), ScoreMode.TOP_SCORES); + + // Initialize the scorer by moving to first doc + DocIdSetIterator scorerIterator = hybridScorer.iterator(); + int firstDoc = scorerIterator.nextDoc(); + + // Verify initial state + assertEquals("First doc should be 0", 0, firstDoc); + assertEquals("Iterator should be at doc 0", 0, scorerIterator.docID()); + + // Get submatches + DisiWrapper result = hybridScorer.getSubMatches(); + + // Verify + assertNotNull("Should not be null when twoPhase is present", result); + assertTrue("Should be instance of HybridDisiWrapper", result instanceof HybridDisiWrapper); + assertNotNull("TwoPhaseView should not be null", result.twoPhaseView); + assertEquals("Should be at doc 0", 0, result.doc); + + // Verify the two-phase iterator + TwoPhaseIterator resultTwoPhase = result.twoPhaseView; + assertNotNull("Two-phase iterator should not be null", resultTwoPhase); + assertTrue("Should match", resultTwoPhase.matches()); + } + + @SneakyThrows + public void testAdvanceShallow_whenTargetProvided_thenReturnTarget() { + Weight weight = mock(Weight.class); + + // Create scorer + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(scorer.iterator()).thenReturn(iterator); + + // Create and setup TwoPhaseIterator + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + when(twoPhase.approximation()).thenReturn(approximation); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + when(twoPhase.matches()).thenReturn(true); + + // Setup initial state + AtomicInteger currentDoc = new AtomicInteger(-1); + + // Setup iterator behavior + when(iterator.docID()).thenAnswer(inv -> currentDoc.get()); + when(approximation.docID()).thenAnswer(inv -> currentDoc.get()); + + // Setup nextDoc behavior + when(iterator.nextDoc()).thenAnswer(inv -> { + currentDoc.set(0); + return 0; + }); + + when(approximation.nextDoc()).thenAnswer(inv -> { + currentDoc.set(0); + return 0; + }); + + // Setup advance behavior + int target = 5; + when(approximation.advance(target)).thenAnswer(inv -> { + currentDoc.set(target); + return target; + }); + + when(iterator.advance(target)).thenAnswer(inv -> { + currentDoc.set(target); + return target; + }); + + // Setup costs + when(iterator.cost()).thenReturn(1L); + when(approximation.cost()).thenReturn(1L); + + // Create hybrid scorer with custom advanceShallow implementation + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer), ScoreMode.TOP_SCORES) { + @Override + public float score() throws IOException { + return 1.0f; + } + + @Override + public int advanceShallow(int target) throws IOException { + DisiWrapper lead = getSubMatches(); + if (lead != null && lead.twoPhaseView != null) { + DocIdSetIterator approx = lead.twoPhaseView.approximation(); + int result = approx.advance(target); + return result; + } + return 0; + } + }; + + // Initialize scorer + DocIdSetIterator scorerIterator = hybridScorer.iterator(); + + // Move to first doc + int firstDoc = scorerIterator.nextDoc(); + assertEquals("Should be at first doc", 0, scorerIterator.docID()); + + // Test advanceShallow + int result = hybridScorer.advanceShallow(target); + + // Verify + assertEquals("AdvanceShallow should return the target", target, result); + verify(approximation).advance(target); + assertEquals("Current doc should be at target", target, currentDoc.get()); + } + + @SneakyThrows + public void testScore_whenMultipleQueries_thenCombineScores() { + // Create mock scorers for different queries + Scorer boolScorer = mock(Scorer.class); + DocIdSetIterator boolIterator = mock(DocIdSetIterator.class); + when(boolScorer.iterator()).thenReturn(boolIterator); + when(boolScorer.docID()).thenReturn(1); + when(boolScorer.score()).thenReturn(0.7f); + + Scorer neuralScorer = mock(Scorer.class); + DocIdSetIterator neuralIterator = mock(DocIdSetIterator.class); + when(neuralScorer.iterator()).thenReturn(neuralIterator); + when(neuralScorer.docID()).thenReturn(1); + when(neuralScorer.score()).thenReturn(0.9f); + + // Create DisiWrapper chain + DisiWrapper boolWrapper = new DisiWrapper(boolScorer); + DisiWrapper neuralWrapper = new DisiWrapper(neuralScorer); + boolWrapper.next = neuralWrapper; + + Weight weight = mock(Weight.class); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Arrays.asList(boolScorer, neuralScorer), ScoreMode.COMPLETE); + float combinedScore = hybridScorer.score(boolWrapper); + + assertEquals("Combined score should be sum of bool and neural scores", 1.6f, combinedScore, DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testScore_whenEmptySubScorers_thenReturnZero() { + Weight weight = mock(Weight.class); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.emptyList()); + float score = hybridScorer.score(null); + + assertEquals("Score should be 0.0 for null wrapper", 0.0f, score, DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testInitialization_whenValidScorer_thenSuccessful() { + // Create scorer with iterator + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + + // Setup state tracking + AtomicInteger currentDoc = new AtomicInteger(-1); + + // Setup iterator behavior + when(iterator.docID()).thenAnswer(inv -> currentDoc.get()); + when(iterator.cost()).thenReturn(1L); + when(iterator.nextDoc()).thenAnswer(inv -> { + if (currentDoc.get() == -1) { + currentDoc.set(0); + return 0; + } + return DocIdSetIterator.NO_MORE_DOCS; + }); + + when(scorer.iterator()).thenReturn(iterator); + when(scorer.docID()).thenAnswer(inv -> currentDoc.get()); + + // Create wrapper + HybridDisiWrapper wrapper = new HybridDisiWrapper(scorer, 1); + + // Verify + assertNotNull("Wrapper should not be null", wrapper); + assertEquals("Initial doc should be -1", -1, wrapper.doc); + assertNotNull("Iterator should not be null", wrapper.iterator); + assertEquals("Cost should be 1", 1L, wrapper.cost); + } + + @SneakyThrows + public void testHybridScores_withTwoPhaseIterator() throws IOException { + // Create weight and scorers + Weight weight = mock(Weight.class); + Scorer scorer1 = mock(Scorer.class); + TwoPhaseIterator twoPhaseIterator = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + + // Setup two-phase behavior + when(scorer1.twoPhaseIterator()).thenReturn(twoPhaseIterator); + when(twoPhaseIterator.approximation()).thenReturn(approximation); + when(scorer1.iterator()).thenReturn(approximation); + when(approximation.cost()).thenReturn(1L); + + // Setup DocIdSetIterator behavior - use different docIDs + when(approximation.docID()).thenReturn(5); // approximation at doc 5 + when(scorer1.docID()).thenReturn(5); // scorer at same doc + when(scorer1.score()).thenReturn(2.0f); + + // matches() always returns false - document should never match + when(twoPhaseIterator.matches()).thenReturn(false); + + // Create HybridQueryScorer with two-phase iterator + List subScorers = Collections.singletonList(scorer1); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, subScorers); + + // Call matches() first to establish non-matching state + TwoPhaseIterator hybridTwoPhase = hybridScorer.twoPhaseIterator(); + assertNotNull("Should have two phase iterator", hybridTwoPhase); + assertFalse("Document should not match", hybridTwoPhase.matches()); + + // Get scores - should be zero since document doesn't match + float[] scores = hybridScorer.hybridScores(); + assertEquals("Should have one score entry", 1, scores.length); + assertEquals("Score should be 0 for non-matching document", 0.0f, scores[0], DELTA_FOR_SCORE_ASSERTION); + + // Verify score() was never called since document didn't match + verify(scorer1, never()).score(); + verify(twoPhaseIterator, times(1)).matches(); + } + + @SneakyThrows + public void testTwoPhaseIterator_withNestedTwoPhaseQuery() { + // Create a scorer that uses two-phase iteration + Scorer scorer = mock(Scorer.class); + TwoPhaseIterator twoPhaseIterator = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + + // Setup the two-phase behavior + when(scorer.twoPhaseIterator()).thenReturn(twoPhaseIterator); + when(twoPhaseIterator.approximation()).thenReturn(approximation); + when(twoPhaseIterator.matches()).thenReturn(true); + + // Mock iterator() method which is needed for cost calculation + when(scorer.iterator()).thenReturn(approximation); + // Mock cost to avoid NPE + when(approximation.cost()).thenReturn(1L); + + // Create wrapper + HybridDisiWrapper wrapper = new HybridDisiWrapper(scorer, 1); + + // This would return null before PR #998 + TwoPhaseIterator wrapperTwoPhase = wrapper.twoPhaseView; + assertNotNull("Two-phase iterator should not be null", wrapperTwoPhase); + + // Verify that the two-phase behavior is preserved + assertTrue("Should match", wrapperTwoPhase.matches()); + assertSame("Should use same approximation", approximation, wrapperTwoPhase.approximation()); + } + protected static Scorer scorerWithTwoPhaseIterator(final int[] docs, final float[] scores, Weight weight, int maxDoc) { final DocIdSetIterator iterator = DocIdSetIterator.all(maxDoc); return new Scorer(weight) {