Skip to content

Commit

Permalink
Add check for empty sub-queries, fix rewrite logic
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed May 22, 2023
1 parent 6bd9f8e commit 5eb6bc3
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ public final class HybridQuery extends Query implements Iterable<Query> {
private final List<Query> subQueries;

public HybridQuery(Collection<Query> subQueries) {
Objects.requireNonNull(subQueries, "Collection of Queries must not be null");
Objects.requireNonNull(subQueries, "Collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("Collection of queries must not be empty");
}
this.subQueries = new ArrayList<>(subQueries);
}

Expand Down Expand Up @@ -81,10 +84,6 @@ public Query rewrite(IndexReader reader) throws IOException {
return new MatchNoDocsQuery("empty HybridQuery");
}

if (subQueries.size() == 1) {
return subQueries.iterator().next();
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.Index;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -119,17 +118,8 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() throws

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT);
queryBuilder.add(termSubQuery);
MapperService mapperService = createMapperService(
fieldMapping(
b -> b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject()
)
);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);

TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
Query queryTwoSubQueries = queryBuilder.doToQuery(mockQueryShardContext);
assertNotNull(queryTwoSubQueries);
Expand Down Expand Up @@ -317,17 +307,8 @@ public void testToXContent() {

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT);
queryBuilder.add(termSubQuery);
MapperService mapperService = createMapperService(
fieldMapping(
b -> b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject()
)
);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);

TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

XContentBuilder builder = XContentFactory.jsonBuilder();
Expand Down Expand Up @@ -532,17 +513,7 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery(
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);

MapperService mapperService = createMapperService(
fieldMapping(
b -> b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject()
)
);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

QueryBuilder queryBuilderAfterRewrite = queryBuilder.doRewrite(mockQueryShardContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

package org.opensearch.neuralsearch.query;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.query.HybridQueryBuilderTests.QUERY_TEXT;
import static org.opensearch.neuralsearch.query.HybridQueryBuilderTests.TERM_QUERY_TEXT;
import static org.opensearch.neuralsearch.query.HybridQueryBuilderTests.TEXT_FIELD_NAME;

import java.io.IOException;
Expand All @@ -35,54 +35,49 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.search.QueryUtils;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.Index;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNQueryBuilder;

public class HybridQueryTests extends OpenSearchQueryTestCase {

static final String VECTOR_FIELD_NAME = "vectorField";
static final String TERM_QUERY_TEXT = "keyword";
static final String TERM_ANOTHER_QUERY_TEXT = "anotherkeyword";
static final float[] VECTOR_QUERY = new float[] {1.0f, 2.0f, 2.1f, 0.6f};
static final int K = 2;

@SneakyThrows
public void testBasics() {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
MapperService mapperService = createMapperService(
fieldMapping(
b -> b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject()
)
);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

HybridQuery query1 = new HybridQuery(List.of());
HybridQuery query2 = new HybridQuery(List.of());
HybridQuery query3 = new HybridQuery(
HybridQuery query1 = new HybridQuery(
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
HybridQuery query2 = new HybridQuery(
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
HybridQuery query3 = new HybridQuery(
List.of(
QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext),
QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext)
)
);
QueryUtils.check(query1);
QueryUtils.checkEqual(query1, query2);
QueryUtils.checkUnequal(query1, query3);
}

public void testRewrite() throws Exception {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
MapperService mapperService = createMapperService(
fieldMapping(
b -> b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject()
)
);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

Directory directory = newDirectory();
Expand All @@ -98,13 +93,27 @@ public void testRewrite() throws Exception {
w.commit();

IndexReader reader = DirectoryReader.open(w);
HybridQuery query = new HybridQuery(
HybridQuery hybridQueryWithTerm = new HybridQuery(
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
Query rewritten = query.rewrite(reader);
QueryUtils.checkUnequal(query, rewritten);
Query rewritten2 = rewritten.rewrite(reader);
assertSame(rewritten, rewritten2);
Query rewritten = hybridQueryWithTerm.rewrite(reader);
//term query is the same after we rewrite it
assertSame(hybridQueryWithTerm, rewritten);


Index dummyIndex = new Index("dummy", "dummy");
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K);
Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext);

HybridQuery hybridQueryWithKnn = new HybridQuery(
List.of(knnQuery)
);
rewritten = hybridQueryWithKnn.rewrite(reader);
assertSame(hybridQueryWithKnn, rewritten);

w.close();
reader.close();
Expand Down Expand Up @@ -230,6 +239,12 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR
dir.close();
}

@SneakyThrows
public void testWithRandomDocuments_whenNoSubQueries_thenFail() {
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of()));
assertThat(exception.getMessage(), containsString("Collection of queries must not be empty"));
}

private static Document getDocument(int docId1, String field1Value, FieldType ft) {
Document doc = new Document();
doc.add(new TextField("id", Integer.toString(docId1), Field.Store.YES));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,17 @@ protected final XContentBuilder mapping(CheckedConsumer<XContentBuilder, IOExcep
protected MapperService createMapperService(XContentBuilder mappings) throws IOException {
return createMapperService(Version.CURRENT, mappings);
}

protected MapperService createMapperService() throws IOException {
return createMapperService(
fieldMapping(
b -> b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject()
)
);
}
}

0 comments on commit 5eb6bc3

Please sign in to comment.