Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Score Normalization and Combination feature #241

Merged
Merged
Prev Previous commit
Next Next commit
Add tests for Hybrid Query (#192)
* Add integ and unit test for query

Signed-off-by: Martin Gaievski <[email protected]>

---------

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Jun 5, 2023
commit b757ff409d3d9840888a7dc0108e041c23e0c24d
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {
private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5;

private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;
private static final String DEFAULT_USER_AGENT = "Kibana";

protected final ClassLoader classLoader = this.getClass().getClassLoader();

@@ -93,7 +94,7 @@ protected String uploadModel(String requestBody) throws Exception {
"/_plugins/_ml/models/_upload",
null,
toHttpEntity(requestBody),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
@@ -122,7 +123,7 @@ protected void loadModel(String modelId) throws Exception {
String.format(LOCALE, "/_plugins/_ml/models/%s/_load", modelId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
@@ -170,7 +171,7 @@ protected float[] runInference(String modelId, String queryText) {
String.format(LOCALE, "/_plugins/_ml/_predict/text_embedding/%s", modelId),
null,
toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"],\"target_response\": [\"sentence_embedding\"]}", queryText)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);

Map<String, Object> inferenceResJson = XContentHelper.convertToMap(
@@ -201,7 +202,7 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
indexName,
null,
toHttpEntity(indexConfiguration),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
@@ -225,7 +226,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro
modelId
)
),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
@@ -403,7 +404,7 @@ protected Map<String, Object> getTaskQueryResponse(String taskId) throws Excepti
String.format(LOCALE, "_plugins/_ml/tasks/%s", taskId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
return XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
@@ -491,4 +492,26 @@ protected static class KNNFieldConfig {
private final Integer dimension;
private final SpaceType spaceType;
}

/**
 * Deletes the model with the given id through the ML plugin REST endpoints.
 * An undeploy request is sent first, because a model that is still in use may not be deletable,
 * and the delete request follows.
 *
 * @param modelId id of the model to undeploy and delete
 */
@SneakyThrows
protected void deleteModel(String modelId) {
    final String undeployEndpoint = String.format(LOCALE, "/_plugins/_ml/models/%s/_undeploy", modelId);
    final String modelEndpoint = String.format(LOCALE, "/_plugins/_ml/models/%s", modelId);
    final BasicHeader userAgentHeader = new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT);

    // step 1: undeploy, as the model can be in use
    makeRequest(client(), "POST", undeployEndpoint, null, toHttpEntity(""), ImmutableList.of(userAgentHeader));

    // step 2: remove the model itself
    makeRequest(client(), "DELETE", modelEndpoint, null, toHttpEntity(""), ImmutableList.of(userAgentHeader));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.plugin;

import java.util.List;

import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.test.OpenSearchTestCase;

/**
 * Unit tests for the query registrations exposed by the {@link NeuralSearch} plugin.
 */
public class NeuralSearchTests extends OpenSearchTestCase {

    /**
     * Checks that the plugin returns a non-empty list of query specs that contains
     * both the neural query and the hybrid query, matched by preferred name.
     */
    public void testQuerySpecs() {
        final List<SearchPlugin.QuerySpec<?>> registeredQueries = new NeuralSearch().getQueries();

        assertNotNull(registeredQueries);
        assertFalse(registeredQueries.isEmpty());
        assertTrue(containsQueryNamed(registeredQueries, NeuralQueryBuilder.NAME));
        assertTrue(containsQueryNamed(registeredQueries, HybridQueryBuilder.NAME));
    }

    /** Returns true when at least one spec in {@code specs} has {@code queryName} as its preferred name. */
    private static boolean containsQueryNamed(List<SearchPlugin.QuerySpec<?>> specs, String queryName) {
        for (SearchPlugin.QuerySpec<?> candidate : specs) {
            if (queryName.equals(candidate.getName().getPreferredName())) {
                return true;
            }
        }
        return false;
    }
}
Original file line number Diff line number Diff line change
@@ -63,7 +63,8 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
static final Supplier<float[]> TEST_VECTOR_SUPPLIER = () -> new float[4];
static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder();

public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully() throws Exception {
@SneakyThrows
public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully() {
HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
@@ -72,7 +73,8 @@ public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully() throws Except
assertTrue(queryNoSubQueries instanceof MatchNoDocsQuery);
}

public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() throws Exception {
@SneakyThrows
public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
@@ -99,7 +101,8 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() throws Excepti
assertNotNull(knnQuery.getQueryVector());
}

public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() throws Exception {
@SneakyThrows
public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
@@ -141,7 +144,8 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() throws
assertEquals(TERM_QUERY_TEXT, termQuery.getTerm().text());
}

public void testDoToQuery_whenTooManySubqueries_thenFail() throws Exception {
@SneakyThrows
public void testDoToQuery_whenTooManySubqueries_thenFail() {
// create query with 6 sub-queries, which is more than current max allowed
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
@@ -225,12 +229,13 @@ public void testDoToQuery_whenTooManySubqueries_thenFail() throws Exception {
* }
* }
*/
public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() throws Exception {
@SneakyThrows
public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startArray("queries")
.startObject()
.startObject("neural")
.startObject(NeuralQueryBuilder.NAME)
.startObject(VECTOR_FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
@@ -240,7 +245,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() thro
.endObject()
.endObject()
.startObject()
.startObject("term")
.startObject(TermQueryBuilder.NAME)
.field(TEXT_FIELD_NAME, TERM_QUERY_TEXT)
.endObject()
.endObject()
@@ -287,7 +292,65 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() thro
}

@SneakyThrows
public void testToXContent() {
public void testFromXContent_whenIncorrectFormat_thenFail() {
    // Checks that HybridQueryBuilder.fromXContent throws ParsingException for two malformed payloads:
    //   1. the sub-queries array is placed under an unsupported field name ("random_field");
    //   2. the "queries" array is present but empty.

    // Case 1: valid neural sub-query, but wrapped in an array with an unsupported name.
    XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startArray("random_field")
        .startObject()
        .startObject(NeuralQueryBuilder.NAME)
        .startObject(VECTOR_FIELD_NAME)
        .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
        .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
        .field(K_FIELD.getPreferredName(), K)
        .field(BOOST_FIELD.getPreferredName(), BOOST)
        .endObject()
        .endObject()
        .endObject()
        .endArray()
        .endObject();

    // Registry shared by both parsers; registers the term, neural and hybrid query parsers.
    NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
        List.of(
            new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
            new NamedXContentRegistry.Entry(
                QueryBuilder.class,
                new ParseField(NeuralQueryBuilder.NAME),
                NeuralQueryBuilder::fromXContent
            ),
            new NamedXContentRegistry.Entry(
                QueryBuilder.class,
                new ParseField(HybridQueryBuilder.NAME),
                HybridQueryBuilder::fromXContent
            )
        )
    );
    XContentParser contentParser = createParser(
        namedXContentRegistry,
        unsupportedFieldXContentBuilder.contentType().xContent(),
        BytesReference.bytes(unsupportedFieldXContentBuilder)
    );
    // advance to the first token before handing the parser to fromXContent
    contentParser.nextToken();

    expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser));

    // Case 2: correct field name, but no sub-queries at all.
    XContentBuilder emptySubQueriesXContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startArray("queries")
        .endArray()
        .endObject();

    // The content type is taken from the same builder whose bytes are parsed, so the parser
    // always matches the payload even if the two builders ever use different formats.
    XContentParser contentParser2 = createParser(
        namedXContentRegistry,
        emptySubQueriesXContentBuilder.contentType().xContent(),
        BytesReference.bytes(emptySubQueriesXContentBuilder)
    );
    contentParser2.nextToken();

    expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser2));
}

@SneakyThrows
public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() {
HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
@@ -344,7 +407,7 @@ public void testToXContent() {
}

@SneakyThrows
public void testStreams() {
public void testStreams_whenWrittingToStream_thenSuccessful() {
HybridQueryBuilder original = new HybridQueryBuilder();
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
.queryText(QUERY_TEXT)
Loading