diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b9af1074..abf123930 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.13...2.x) ### Features ### Enhancements +- Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670)) ### Bug Fixes - Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 01d271cdd..db09f6ebc 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -29,12 +29,32 @@ public final class HybridQuery extends Query implements Iterable { private final List subQueries; - public HybridQuery(Collection subQueries) { + /** + * Create new instance of hybrid query object based on collection of sub queries and filter query + * @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores + * @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is + */ + public HybridQuery(final Collection subQueries, final List filterQueries) { Objects.requireNonNull(subQueries, "collection of queries must not be null"); if (subQueries.isEmpty()) { throw new IllegalArgumentException("collection of queries must not be empty"); } - this.subQueries = new ArrayList<>(subQueries); + if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { + this.subQueries = new ArrayList<>(subQueries); + } else { + List modifiedSubQueries = new ArrayList<>(); + for (Query subQuery : subQueries) { + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(subQuery, BooleanClause.Occur.MUST); + filterQueries.forEach(filterQuery -> builder.add(filterQuery, BooleanClause.Occur.FILTER)); + modifiedSubQueries.add(builder.build()); + } + this.subQueries = modifiedSubQueries; + } + } + + public HybridQuery(final Collection subQueries) { + this(subQueries, List.of()); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 4d8b429df..b97134f8f 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -7,6 +7,7 @@ import java.io.IOException; import java.util.LinkedList; import java.util.List; +import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; import org.apache.lucene.search.BooleanClause; @@ -14,7 +15,6 @@ import org.apache.lucene.search.Query; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; -import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; @@ -25,6 +25,8 @@ import lombok.extern.log4j.Log4j2; +import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasAliasFilter; +import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasNestedFieldOrNestedDocs; import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery; /** @@ -51,10 +53,6 @@ public boolean searchWith( } } - private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { - return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); - } - private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); @@ -62,15 +60,20 @@ private static boolean isWrappedHybridQuery(final Query query) { @VisibleForTesting protected Query extractHybridQuery(final SearchContext searchContext, final Query query) { - if (hasNestedFieldOrNestedDocs(query, searchContext) + if ((hasAliasFilter(query, searchContext) || hasNestedFieldOrNestedDocs(query, searchContext)) && isWrappedHybridQuery(query) - && ((BooleanQuery) query).clauses().size() > 0) { - // extract hybrid query and replace bool with hybrid query + && !((BooleanQuery) query).clauses().isEmpty()) { List booleanClauses = ((BooleanQuery) query).clauses(); - if (booleanClauses.isEmpty() || booleanClauses.get(0).getQuery() instanceof HybridQuery == false) { - throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level bool query"); + if (!(booleanClauses.get(0).getQuery() instanceof HybridQuery)) { + throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level query"); } - return booleanClauses.get(0).getQuery(); + HybridQuery hybridQuery = (HybridQuery) booleanClauses.stream().findFirst().get().getQuery(); + List filterQueries = booleanClauses.stream() + .filter(clause -> BooleanClause.Occur.FILTER == clause.getOccur()) + .map(BooleanClause::getQuery) + .collect(Collectors.toList()); + HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries); + return hybridQueryWithFilter; } return query; } diff --git a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java index 689cbedca..d19985f5c 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java @@ -6,15 +6,14 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; -import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.search.internal.SearchContext; +import java.util.Objects; + /** * Utility class for anything related to hybrid query */ @@ -24,7 +23,7 @@ public class HybridQueryUtil { public static boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; - } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { + } else if (isWrappedHybridQuery(query)) { /* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. main reason for that is performance optimization, at time of writing we are ok with loosing on performance if that's unblocks @@ -34,33 +33,26 @@ public static boolean isHybridQuery(final Query query, final SearchContext searc below is sample structure of such query: Boolean { - should: { - hybrid: { - sub_query1 {} - sub_query2 {} - } - } - filter: { - exists: { - field: "_primary_term" - } - } + should: { + hybrid: { + sub_query1 {} + sub_query2 {} + } + } + filter: { + exists: { + field: "_primary_term" + } + } } - TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */ + */ // we have already checked if query in instance of Boolean in higher level else if condition - return ((BooleanQuery) query).clauses() - .stream() - .filter(clause -> clause.getQuery() instanceof HybridQuery == false) - .allMatch(clause -> { - return clause.getOccur() == BooleanClause.Occur.FILTER - && clause.getQuery() instanceof FieldExistsQuery - && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()); - }); + return hasNestedFieldOrNestedDocs(query, searchContext) || hasAliasFilter(query, searchContext); } return false; } - private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + public static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } @@ -68,4 +60,8 @@ private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } + + public static boolean hasAliasFilter(final Query query, final SearchContext searchContext) { + return Objects.nonNull(searchContext.aliasFilter()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 38aa69075..ea59392da 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -28,6 +28,7 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -40,6 +41,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_INDEX_NAME = "test-hybrid-basic-index"; private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-hybrid-vector-doc-field-index"; + private static final String TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME = "test-hybrid-multi-doc-nested-fields-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-hybrid-multi-doc-index"; private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-hybrid-multi-doc-single-shard-index"; private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = @@ -256,7 +258,7 @@ public void testComplexQuery_whenMultipleIdenticalSubQueries_thenSuccessful() { public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() { String modelId = null; try { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + initializeIndexIfNotExist(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME); modelId = prepareModel(); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); @@ -266,7 +268,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( hybridQueryBuilderOnlyTerm.add(termQuery2Builder); Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_NAME, + TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, hybridQueryBuilderOnlyTerm, null, 10, @@ -283,7 +285,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( assertNotNull(total.get("relation")); assertEquals(RELATION_EQUAL_TO, total.get("relation")); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, modelId, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, null, modelId, SEARCH_PIPELINE); } } @@ -578,6 +580,102 @@ public void testRequestCache_whenMultipleShardsQueryReturnResults_thenSuccessful } } + @SneakyThrows + public void testWrappedQueryWithFilter_whenIndexAliasHasFilterAndIndexWithNestedFields_thenSuccess() { + String modelId = null; + String alias = "alias_with_filter"; + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME); + modelId = prepareModel(); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + // create alias for index + QueryBuilder aliasFilter = QueryBuilders.boolQuery() + .mustNot(QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + createIndexAlias(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, alias, aliasFilter); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + "", + modelId, + 5, + null, + null + ); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + + Map searchResponseAsMap = search( + alias, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertEquals(1.0f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(2, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + deleteIndexAlias(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, alias); + wipeOfTestResources(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, null, modelId, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testWrappedQueryWithFilter_whenIndexAliasHasFilters_thenSuccess() { + String modelId = null; + String alias = "alias_with_filter"; + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + modelId = prepareModel(); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + // create alias for index + QueryBuilder aliasFilter = QueryBuilders.boolQuery() + .mustNot(QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + createIndexAlias(TEST_MULTI_DOC_INDEX_NAME, alias, aliasFilter); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + "", + modelId, + 5, + null, + null + ); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + + Map searchResponseAsMap = search( + alias, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertEquals(1.0f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(2, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + deleteIndexAlias(TEST_MULTI_DOC_INDEX_NAME, alias); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, modelId, SEARCH_PIPELINE); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -628,10 +726,28 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { assertEquals(3, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)); } + if (TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), + 1 + ), + "" + ); + addDocsToIndex(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME); + } + if (TEST_MULTI_DOC_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME)) { - prepareKnnIndex( - TEST_MULTI_DOC_INDEX_NAME, - Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of(), + 1 + ), + "" ); addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index 4f645f570..b74bd010c 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -24,7 +24,9 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; @@ -276,4 +278,40 @@ public void testToString_whenCallQueryToString_thenSuccessful() { String queryString = query.toString(TEXT_FIELD_NAME); assertEquals("(keyword | anotherkeyword | (keyword anotherkeyword))", queryString); } + + @SneakyThrows + public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Query filter = QueryBuilders.boolQuery().mustNot(QueryBuilders.matchQuery(TERM_QUERY_TEXT, "test")).toQuery(mockQueryShardContext); + + HybridQuery hybridQuery = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) + ), + List.of(filter) + ); + QueryUtils.check(hybridQuery); + + Iterator queryIterator = hybridQuery.iterator(); + assertNotNull(queryIterator); + int countOfQueries = 0; + while (queryIterator.hasNext()) { + Query query = queryIterator.next(); + assertNotNull(query); + assertTrue(query instanceof BooleanQuery); + BooleanQuery booleanQuery = (BooleanQuery) query; + assertEquals(2, booleanQuery.clauses().size()); + Query subQuery = booleanQuery.clauses().get(0).getQuery(); + assertTrue(subQuery instanceof TermQuery); + Query filterQuery = booleanQuery.clauses().get(1).getQuery(); + assertTrue(filterQuery instanceof BooleanQuery); + assertTrue(((BooleanQuery) filterQuery).clauses().get(0).getQuery() instanceof MatchNoDocsQuery); + countOfQueries++; + } + assertEquals(2, countOfQueries); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 055301832..6fbc86dea 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -15,11 +15,10 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; +import static org.opensearch.index.mapper.SeqNoFieldMapper.PRIMARY_TERM_NAME; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedList; import java.util.List; @@ -27,7 +26,10 @@ import java.util.UUID; import java.util.stream.Collectors; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.TextField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexOptions; @@ -39,7 +41,6 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.action.OriginalIndices; @@ -602,7 +603,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur org.hamcrest.MatcherAssert.assertThat( exception.getMessage(), - containsString("cannot process hybrid query due to incorrect structure of top level bool query") + containsString("cannot process hybrid query due to incorrect structure of top level query") ); releaseResources(directory, w, reader); @@ -632,6 +633,12 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); when(mockQueryShardContext.getMapperService()).thenReturn(mapperService); when(mockQueryShardContext.simpleMatchToIndexNames(anyString())).thenReturn(Set.of(TEXT_FIELD_NAME)); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); Directory directory = newDirectory(); IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); @@ -643,10 +650,10 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then int docId2 = RandomizedTest.randomInt(); int docId3 = RandomizedTest.randomInt(); int docId4 = RandomizedTest.randomInt(); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); + w.addDocument(document(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(document(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(document(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(document(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); w.commit(); IndexReader reader = DirectoryReader.open(w); @@ -681,6 +688,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -692,7 +700,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); BooleanQuery.Builder builder = new BooleanQuery.Builder(); - builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD) + builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.MUST) .add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER); Query query = builder.build(); @@ -887,19 +895,107 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { } @SneakyThrows - private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { - assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); - assertNotNull(subQueryTopDocs.scoreDocs); - assertEquals(expectedDocIds.size(), subQueryTopDocs.scoreDocs.length); - assertEquals(TotalHits.Relation.EQUAL_TO, subQueryTopDocs.totalHits.relation); - for (int i = 0; i < expectedDocIds.size(); i++) { - int expectedDocId = expectedDocIds.get(i); - ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[i]; - assertNotNull(scoreDoc); - int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); - assertEquals(expectedDocId, actualDocId); - assertTrue(scoreDoc.score > 0.0f); - } + public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + MapperService mapperService = createMapperService(mapping(b -> { + b.startObject("field"); + b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject(); + b.endObject(); + })); + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + when(mockQueryShardContext.getMapperService()).thenReturn(mapperService); + when(mockQueryShardContext.simpleMatchToIndexNames(anyString())).thenReturn(Set.of(TEXT_FIELD_NAME)); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + int docId4 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + Query termFilter = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD).add(termFilter, BooleanClause.Occur.FILTER); + Query query = builder.build(); + + when(searchContext.query()).thenReturn(query); + when(searchContext.aliasFilter()).thenReturn(termFilter); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertTrue(topDocs.totalHits.value > 0); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + assertNotNull(scoreDocs); + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; + assertTrue(scoreDoc.score > 0); + assertEquals(0, scoreDoc.doc); + + releaseResources(directory, w, reader); } private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { @@ -908,28 +1004,6 @@ private void releaseResources(Directory directory, IndexWriter w, IndexReader re directory.close(); } - private List getSubQueryResultsForSingleShard(final TopDocs topDocs) { - assertNotNull(topDocs); - List topDocsList = new ArrayList<>(); - ScoreDoc[] scoreDocs = topDocs.scoreDocs; - // skipping 0 element, it's a start-stop element - List scoreDocList = new ArrayList<>(); - for (int index = 2; index < scoreDocs.length; index++) { - // getting first element of score's series - ScoreDoc scoreDoc = scoreDocs[index]; - if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) { - ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]); - TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO); - TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores); - topDocsList.add(subQueryTopDocs); - scoreDocList.clear(); - } else { - scoreDocList.add(scoreDoc); - } - } - return topDocsList; - } - private BooleanQuery createNestedBoolQuery(final Query query1, final Query query2, int level) { if (level == 0) { BooleanQuery.Builder builder = new BooleanQuery.Builder(); @@ -940,4 +1014,12 @@ private BooleanQuery createNestedBoolQuery(final Query query1, final Query query builder.add(createNestedBoolQuery(query1, query2, level - 1), BooleanClause.Occur.MUST); return builder.build(); } + + private static Document document(final String fieldName, int docId, final String fieldValue, final FieldType ft) { + Document doc = new Document(); + doc.add(new TextField("id", Integer.toString(docId), Field.Store.YES)); + doc.add(new Field(fieldName, fieldValue, ft)); + doc.add(new NumericDocValuesField(PRIMARY_TERM_NAME, 0)); + return doc; + } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index ea006db57..2de428b17 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -348,6 +348,43 @@ protected void createSearchPipelineViaConfig(String modelId, String pipelineName assertEquals("true", node.get("acknowledged").toString()); } + protected void createIndexAlias(final String index, final String alias, final QueryBuilder filterBuilder) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.startArray("actions"); + builder.startObject(); + builder.startObject("add"); + builder.field("index", index); + builder.field("alias", alias); + // filter object + if (Objects.nonNull(filterBuilder)) { + builder.field("filter"); + filterBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + } + builder.endObject(); + builder.endObject(); + builder.endArray(); + builder.endObject(); + + Request request = new Request("POST", "/_aliases"); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + @SneakyThrows + protected void deleteIndexAlias(final String index, final String alias) { + makeRequest( + client(), + "DELETE", + String.format(Locale.ROOT, "%s/_alias/%s", index, alias), + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + /** * Get the number of documents in a particular index *