Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Add AbstractRetriverTool, VectorDBTool, NeuralSparseTools #40

Merged
merged 15 commits into from
Dec 22, 2023
15 changes: 10 additions & 5 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ configurations {
zipArchive
all {
resolutionStrategy {
force "org.mockito:mockito-core:5.8.0"
force "org.mockito:mockito-core:${versions.mockito}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what will be the current version in this case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

force "com.google.guava:guava:32.1.3-jre" // CVE for 31.1
force("org.eclipse.platform:org.eclipse.core.runtime:3.30.0") // CVE for < 3.29.0, forces JDK17 for spotless
}
Expand All @@ -103,12 +103,15 @@ dependencies {
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.22.0"
compileOnly group: 'org.json', name: 'json', version: '20231013'
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}"
compileOnly "org.opensearch:common-utils:${version}"

implementation("com.google.guava:guava:32.1.3-jre")
implementation group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}"
implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"])
compileOnly "org.opensearch:common-utils:${version}"

testImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation "org.mockito:mockito-core:5.8.0"
testImplementation "net.bytebuddy:byte-buddy-agent:${versions.bytebuddy}"
testImplementation "org.mockito:mockito-core:${versions.mockito}"
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1'
testImplementation 'org.mockito:mockito-junit-jupiter:5.8.0'
testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0"
Expand All @@ -118,6 +121,7 @@ dependencies {

// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}"
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
}

task extractSqlJar(type: Copy) {
Expand All @@ -137,12 +141,13 @@ testingConventions.enabled = false
thirdPartyAudit.enabled = false

test {
useJUnitPlatform()
testLogging {
exceptionFormat "full"
events "skipped", "passed", "failed" // "started"
showStandardStreams true
}
include '**/*Tests.class'
systemProperty 'tests.security.manager', 'false'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment what is this for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns off the security manager to make it easy for writing test case. It is used in many build script of Opensearch plugins. However we don't have such test cases currently. So I just removed this line and we can add it back when we need it.

}

spotless {
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import java.util.List;
import java.util.function.Supplier;

import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -54,11 +56,13 @@ public Collection<Object> createComponents(
this.xContentRegistry = xContentRegistry;

PPLTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
return Collections.emptyList();
}

@Override
public List<Tool.Factory<? extends Tool>> getToolFactories() {
return List.of(PPLTool.Factory.getInstance());
return List.of(PPLTool.Factory.getInstance(), NeuralSparseSearchTool.Factory.getInstance(), VectorDBTool.Factory.getInstance());
}
}
143 changes: 143 additions & 0 deletions src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* Abstract tool supports search paradigms in neural-search plugin.
*/
@Log4j2
@Getter
@Setter
public abstract class AbstractRetrieverTool implements Tool {
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index.";
public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
public static final String SOURCE_FIELD = "source_field";
public static final String DOC_SIZE_FIELD = "doc_size";
public static final int DEFAULT_DOC_SIZE = 2;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this to control the matching docs size? I remember it's 10 by default in search API. Are we giving 2 randomly or there's a reason behind this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This number is inherited from origin VectorDBTool. I guess the reason maybe we want to limit the context size for LLM. I'm open for changing this number

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's confirm and add a comment for this field. I'm also curious to know more about this field.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The origin VectorDBTool was introduced from the agent framework POC commit. opensearch-project/ml-commons@1b85cff#diff-b353ce1eda5b942809ea500dadcfda5769504edd8fa17fc174d498937097c69eR67 Hi @ylwu-amzn, do you know why the default doc size is 2? Do we set this for context length?

By the way I don't think this is a blocker issue for this PR. This parameter is configurable, and we can alter its value after our e2e test.


protected String description = DEFAULT_DESCRIPTION;
protected Client client;
protected NamedXContentRegistry xContentRegistry;
protected String index;
protected String[] sourceFields;
protected Integer docSize;
protected String version;

protected AbstractRetrieverTool(
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String[] sourceFields,
Integer docSize
) {
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
this.sourceFields = sourceFields;
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
}

protected abstract String getQueryBody(String queryText);

private <T> SearchRequest buildSearchRequest(Map<String, String> parameters, ActionListener<T> listener) throws IOException {
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
String question = parameters.get(INPUT_FIELD);
if (StringUtils.isBlank(question)) {
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it.");
}

String query = getQueryBody(question);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.fetchSource(sourceFields, null);
searchSourceBuilder.size(docSize);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
return searchRequest;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
SearchRequest searchRequest;
try {
searchRequest = buildSearchRequest(parameters, listener);
} catch (Exception e) {
listener.onFailure(e);
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
return;
}

ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any specific reason to extract these fields instead of deserialize the whole searchResponse as result?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/opensearch-project/OpenSearch/blob/2c8ee1947b55a1cc5bc1a114b82e3a3b8a99851e/server/src/main/java/org/opensearch/search/SearchHit.java#L616
If we just deserialize the object, the unnecessary fields may increase the size of context. E.g. nestedIdentity, version, seqNo, primaryTerm etc.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result looks like in jsonl format, is this as expected?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this result will be treated as a norm string and be sent to LLM.

docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
contextBuilder.append(gson.toJson(docContent)).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
} else {
listener.onResponse((T) "");
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
}
}, e -> {
log.error("Failed to search index.", e);
listener.onFailure(e);
});
client.search(searchRequest, actionListener);
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
return false;
}
String question = parameters.get("input");
return !StringUtils.isBlank(question);
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
}

protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> {
protected Client client;
protected NamedXContentRegistry xContentRegistry;

public void init(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
110 changes: 110 additions & 0 deletions src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* This tool supports neural_sparse search with sparse encoding models and rank_features field.
*/
@Log4j2
@Getter
@Setter
@ToolAnnotation(NeuralSparseSearchTool.TYPE)
public class NeuralSparseSearchTool extends AbstractRetrieverTool {
public static final String TYPE = "NeuralSparseSearchTool";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";

private String name = TYPE;
private String modelId;
private String embeddingField;

@Builder
public NeuralSparseSearchTool(
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String embeddingField,
String[] sourceFields,
Integer docSize,
String modelId
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.modelId = modelId;
this.embeddingField = embeddingField;
}

@Override
protected String getQueryBody(String queryText) {
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException(
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
);
}
return "{\"query\":{\"neural_sparse\":{\""
+ embeddingField
+ "\":{\"query_text\":\""
+ queryText
+ "\",\"model_id\":\""
+ modelId
+ "\"}}}"
+ " }";
}

@Override
public String getType() {
return TYPE;
}

public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool> {
private static Factory INSTANCE;

public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (NeuralSparseSearchTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

@Override
public NeuralSparseSearchTool create(Map<String, Object> params) {
String index = (String) params.get(INDEX_FIELD);
String embeddingField = (String) params.get(EMBEDDING_FIELD);
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
return NeuralSparseSearchTool
.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.index(index)
.embeddingField(embeddingField)
.sourceFields(sourceFields)
.modelId(modelId)
.docSize(docSize)
.build();
}
}
}
Loading
Loading