From c4d0a0c9c65c1bb119370c9266fab44af6cc52fe Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 14 Jul 2023 15:59:34 +0200 Subject: [PATCH] Add feature flag for QueryPhaseSearcher (#214) * Adding hybrid_search_enabled settings Signed-off-by: Martin Gaievski --- build.gradle | 2 + .../neuralsearch/plugin/NeuralSearch.java | 19 +++++-- .../query/HybridQueryBuilder.java | 6 +-- .../query/NeuralQueryBuilder.java | 6 +-- .../opensearch/neuralsearch/TestUtils.java | 2 +- .../common/BaseNeuralSearchIT.java | 2 +- .../plugin/NeuralSearchTests.java | 15 ++++-- .../query/HybridQueryBuilderTests.java | 15 +++--- .../neuralsearch/query/HybridQueryIT.java | 4 +- .../neuralsearch/query/HybridQueryTests.java | 2 +- .../query/NeuralQueryBuilderTests.java | 10 ++-- .../query/OpenSearchQueryTestCase.java | 9 +++- .../query/HybridQueryPhaseSearcherTests.java | 51 +++++++++++++++---- 13 files changed, 103 insertions(+), 40 deletions(-) diff --git a/build.gradle b/build.gradle index 1d8eca483..6e4b4ada4 100644 --- a/build.gradle +++ b/build.gradle @@ -253,6 +253,8 @@ testClusters.integTest { // Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due // to ml-commons memory circuit breaker exception jvmArgs("-Xms1g", "-Xmx1g") + // enable hybrid search for testing + systemProperty('neural_search_hybrid_search_enabled', 'true') } // Remote Integration Tests diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 21c4e5537..83f8c396b 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -16,7 +16,8 @@ import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.util.FeatureFlags; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; @@ -39,11 +40,19 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; +import com.google.common.annotations.VisibleForTesting; + /** * Neural Search plugin class */ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin { - + /** + * Gates the functionality of hybrid search + * Currently query phase searcher added with hybrid search will conflict with concurrent search in core. + * Once that problem is resolved this feature flag can be removed. + */ + @VisibleForTesting + public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled"; private MLCommonsClientAccessor clientAccessor; @Override @@ -80,6 +89,10 @@ public Map getProcessors(Processor.Parameters paramet @Override public Optional getQueryPhaseSearcher() { - return Optional.of(new HybridQueryPhaseSearcher()); + if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) { + return Optional.of(new HybridQueryPhaseSearcher()); + } + // we want feature be disabled by default due to risk of colliding and breaking concurrent search in core + return Optional.empty(); } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index e97c20dec..19540550d 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -21,11 +21,11 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.lucene.search.Query; -import org.opensearch.common.ParsingException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.lucene.search.Queries; import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.AbstractQueryBuilder; diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 50a246348..a7548dbe5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -23,11 +23,11 @@ import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.search.Query; import org.opensearch.action.ActionListener; -import org.opensearch.common.ParsingException; import org.opensearch.common.SetOnce; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.NumberFieldMapper; diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 4f0c5129d..257545132 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -9,8 +9,8 @@ import java.util.Map; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; public class TestUtils { diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 59414b49b..e1f907a5c 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -40,12 +40,12 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase; -import org.opensearch.rest.RestStatus; import com.google.common.collect.ImmutableList; diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index c2925bf43..c4b1d49f7 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -15,12 +15,12 @@ import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.SearchPlugin; import org.opensearch.search.query.QueryPhaseSearcher; -import org.opensearch.test.OpenSearchTestCase; -public class NeuralSearchTests extends OpenSearchTestCase { +public class NeuralSearchTests extends OpenSearchQueryTestCase { public void testQuerySpecs() { NeuralSearch plugin = new NeuralSearch(); @@ -37,8 +37,15 @@ public void testQueryPhaseSearcher() { Optional queryPhaseSearcher = plugin.getQueryPhaseSearcher(); assertNotNull(queryPhaseSearcher); - assertFalse(queryPhaseSearcher.isEmpty()); - assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher); + assertTrue(queryPhaseSearcher.isEmpty()); + + initFeatureFlags(); + + Optional queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher(); + + assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled); + assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty()); + assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher); } public void testProcessors() { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 06b07fed5..49d1ba974 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -28,25 +28,26 @@ import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; -import org.opensearch.common.ParsingException; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.FilterStreamInput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.FilterStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.index.Index; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.Index; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -82,6 +83,7 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) @@ -110,6 +112,7 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 830191156..3414210b9 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -138,7 +138,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, modelId.get(), - 1, + 3, null, null ); @@ -150,7 +150,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect( Map searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3); - assertEquals(2, getHitCount(searchResponseAsMap)); + assertEquals(3, getHitCount(searchResponseAsMap)); List> hitsNestedList = getNestedHits(searchResponseAsMap); List scores = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index 21287ccb7..4d24c2049 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -34,7 +34,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.search.QueryUtils; -import org.opensearch.index.Index; +import org.opensearch.core.index.Index; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 7f48e61f9..dda4713b9 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -31,14 +31,14 @@ import org.opensearch.action.ActionListener; import org.opensearch.client.Client; -import org.opensearch.common.ParsingException; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.FilterStreamInput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.FilterStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index e954004c7..c45c4adce 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -8,6 +8,7 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonMap; import static java.util.stream.Collectors.toList; +import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED; import java.io.IOException; import java.util.Arrays; @@ -28,10 +29,11 @@ import org.opensearch.Version; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.CheckedConsumer; -import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.SuppressForbidden; import org.opensearch.common.compress.CompressedXContent; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexSettings; import org.opensearch.index.analysis.AnalyzerScope; @@ -222,4 +224,9 @@ public float getMaxScore(int upTo) { } }; } + + @SuppressForbidden(reason = "manipulates system properties for testing") + protected static void initFeatureFlags() { + System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, "true"); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index f08dc2332..f63df42c9 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -36,14 +36,15 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.action.OriginalIndices; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; -import org.opensearch.index.Index; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.index.shard.ShardId; +import org.opensearch.index.shard.IndexShard; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; @@ -95,6 +96,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -102,10 +104,10 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -115,6 +117,12 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { ); when(searchContext.shardTarget()).thenReturn(shardTarget); when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -127,6 +135,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -155,6 +165,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -162,10 +173,10 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -175,7 +186,13 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() ); when(searchContext.shardTarget()).thenReturn(shardTarget); when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.queryResult()).thenReturn(new QuerySearchResult()); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -216,6 +233,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -223,10 +241,10 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -237,8 +255,14 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { when(searchContext.shardTarget()).thenReturn(shardTarget); when(searchContext.searcher()).thenReturn(contextIndexSearcher); when(searchContext.size()).thenReturn(3); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -300,6 +324,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -307,10 +332,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -323,6 +348,12 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(searchContext.size()).thenReturn(4); QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean();