-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: Add AbstractRetriverTool, VectorDBTool, NeuralSparseTools #40
Changes from 11 commits
7d1fb29
355bbe7
340a6ed
1114732
7fe33f6
fd3cf33
4311116
02f0d45
43a228f
97f92e9
de436a4
6ec9b0c
ef03c34
b1dbe84
b3a964b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,7 +82,7 @@ configurations { | |
zipArchive | ||
all { | ||
resolutionStrategy { | ||
force "org.mockito:mockito-core:5.8.0" | ||
force "org.mockito:mockito-core:${versions.mockito}" | ||
force "com.google.guava:guava:32.1.3-jre" // CVE for 31.1 | ||
force("org.eclipse.platform:org.eclipse.core.runtime:3.30.0") // CVE for < 3.29.0, forces JDK17 for spotless | ||
} | ||
|
@@ -103,12 +103,15 @@ dependencies { | |
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' | ||
compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.22.0" | ||
compileOnly group: 'org.json', name: 'json', version: '20231013' | ||
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" | ||
implementation("com.google.guava:guava:32.1.3-jre") | ||
implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"]) | ||
compileOnly "org.opensearch:common-utils:${version}" | ||
compileOnly("com.google.guava:guava:32.1.2-jre") | ||
compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' | ||
|
||
implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"]) | ||
|
||
testImplementation "org.opensearch.test:framework:${opensearch_version}" | ||
testImplementation "org.mockito:mockito-core:5.8.0" | ||
testImplementation "net.bytebuddy:byte-buddy-agent:${versions.bytebuddy}" | ||
testImplementation "org.mockito:mockito-core:${versions.mockito}" | ||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1' | ||
testImplementation 'org.mockito:mockito-junit-jupiter:5.8.0' | ||
testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0" | ||
|
@@ -118,6 +121,7 @@ dependencies { | |
|
||
// ZipArchive dependencies used for integration tests | ||
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" | ||
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" | ||
zhichao-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
task extractSqlJar(type: Copy) { | ||
|
@@ -137,12 +141,13 @@ testingConventions.enabled = false | |
thirdPartyAudit.enabled = false | ||
|
||
test { | ||
useJUnitPlatform() | ||
testLogging { | ||
exceptionFormat "full" | ||
events "skipped", "passed", "failed" // "started" | ||
showStandardStreams true | ||
} | ||
include '**/*Tests.class' | ||
systemProperty 'tests.security.manager', 'false' | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's add a comment explaining what this is for. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It turns off the security manager to make it easier to write test cases. It is used in the build scripts of many OpenSearch plugins. However, we don't have such test cases currently, so I just removed this line and we can add it back when we need it. |
||
} | ||
|
||
spotless { | ||
|
@@ -165,6 +170,8 @@ compileJava { | |
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) | ||
} | ||
|
||
forbiddenApisTest.ignoreFailures = true | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add a comment explaining why we need this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's inherited from the ml-commons build script. I tried to build without this line and got errors like
I tried to do some research but could not find any useful information. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. And if I remove the @Test annotation, it just cannot find the test case |
||
|
||
compileTestJava { | ||
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Abstract tool supports search paradigms in neural-search plugin. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
public abstract class AbstractRetrieverTool implements Tool { | ||
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; | ||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String SOURCE_FIELD = "source_field"; | ||
public static final String DOC_SIZE_FIELD = "doc_size"; | ||
public static final int DEFAULT_DOC_SIZE = 2; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this to control the number of matching docs returned? I remember it's 10 by default in the search API. Are we picking 2 arbitrarily, or is there a reason behind this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This number is inherited from the original VectorDBTool. I guess the reason may be that we want to limit the context size for the LLM. I'm open to changing this number. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's confirm and add a comment for this field. I'm also curious to know more about this field. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original VectorDBTool was introduced in the agent framework POC commit. opensearch-project/ml-commons@1b85cff#diff-b353ce1eda5b942809ea500dadcfda5769504edd8fa17fc174d498937097c69eR67 Hi @ylwu-amzn, do you know why the default doc size is 2? Do we set this for context length? By the way, I don't think this is a blocking issue for this PR. This parameter is configurable, and we can alter its value after our e2e test. |
||
|
||
protected String description = DEFAULT_DESCRIPTION; | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
protected String index; | ||
protected String[] sourceFields; | ||
protected Integer docSize; | ||
protected String version; | ||
|
||
/**
 * Stores the collaborators and search settings shared by retriever tools.
 *
 * @param client client used to execute the search request
 * @param xContentRegistry registry used when parsing the query body
 * @param index index the search request targets
 * @param sourceFields source fields to fetch for each hit
 * @param docSize number of hits to request; {@code null} selects {@link #DEFAULT_DOC_SIZE}
 */
protected AbstractRetrieverTool(
    Client client,
    NamedXContentRegistry xContentRegistry,
    String index,
    String[] sourceFields,
    Integer docSize
) {
    // Fall back to the default hit count when the caller did not specify one.
    if (docSize == null) {
        this.docSize = DEFAULT_DOC_SIZE;
    } else {
        this.docSize = docSize;
    }
    this.sourceFields = sourceFields;
    this.index = index;
    this.xContentRegistry = xContentRegistry;
    this.client = client;
}
|
||
/**
 * Builds the JSON search body for the given query text. The returned string is parsed
 * as the search source by this class before the request is sent.
 *
 * @param queryText non-blank value of the {@code input} parameter supplied to the tool
 * @return JSON text of the search source
 */
protected abstract String getQueryBody(String queryText); | ||
|
||
private <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException { | ||
String question = parameters.get(INPUT_FIELD); | ||
if (StringUtils.isBlank(question)) { | ||
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); | ||
} | ||
|
||
String query = getQueryBody(question); | ||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
searchSourceBuilder.fetchSource(sourceFields, null); | ||
searchSourceBuilder.size(docSize); | ||
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); | ||
return searchRequest; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
SearchRequest searchRequest; | ||
try { | ||
searchRequest = buildSearchRequest(parameters); | ||
} catch (Exception e) { | ||
log.error("Failed to build search request.", e); | ||
listener.onFailure(e); | ||
zhichao-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return; | ||
} | ||
|
||
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (int i = 0; i < hits.length; i++) { | ||
SearchHit hit = hits[i]; | ||
Map<String, Object> docContent = new HashMap<>(); | ||
docContent.put("_index", hit.getIndex()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there any specific reason to extract these fields instead of deserializing the whole searchResponse as the result? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. https://github.com/opensearch-project/OpenSearch/blob/2c8ee1947b55a1cc5bc1a114b82e3a3b8a99851e/server/src/main/java/org/opensearch/search/SearchHit.java#L616 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The result looks like it is in JSONL format; is this as expected? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, this result will be treated as a normal string and sent to the LLM. |
||
docContent.put("_id", hit.getId()); | ||
docContent.put("_score", hit.getScore()); | ||
docContent.put("_source", hit.getSourceAsMap()); | ||
contextBuilder.append(gson.toJson(docContent)).append("\n"); | ||
} | ||
listener.onResponse((T) contextBuilder.toString()); | ||
} else { | ||
listener.onResponse((T) "Can not get any match from search result."); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this a customer-facing response? If yes, how about: There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Based on my mental model, the retrieval tools are used for RAG, so their response will be parsed by an LLM instead of a human. We don't have tool-retry logic now, so the "refine your search terms" wording may confuse the LLM |
||
} | ||
}, e -> { | ||
log.error("Failed to search index.", e); | ||
listener.onFailure(e); | ||
}); | ||
client.search(searchRequest, actionListener); | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
return parameters != null && parameters.size() > 0 && !StringUtils.isBlank(parameters.get("input")); | ||
} | ||
|
||
protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> { | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* This tool supports neural_sparse search with sparse encoding models and rank_features field. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
@ToolAnnotation(NeuralSparseSearchTool.TYPE) | ||
public class NeuralSparseSearchTool extends AbstractRetrieverTool { | ||
public static final String TYPE = "NeuralSparseSearchTool"; | ||
public static final String MODEL_ID_FIELD = "model_id"; | ||
public static final String EMBEDDING_FIELD = "embedding_field"; | ||
|
||
private String name = TYPE; | ||
private String modelId; | ||
private String embeddingField; | ||
|
||
/**
 * Creates a neural sparse search tool.
 *
 * @param client client forwarded to the base retriever tool
 * @param xContentRegistry registry forwarded to the base retriever tool
 * @param index index forwarded to the base retriever tool
 * @param embeddingField field name placed under {@code neural_sparse} in the query body
 * @param sourceFields source fields forwarded to the base retriever tool
 * @param docSize hit count forwarded to the base retriever tool
 * @param modelId model id placed in the query body
 */
@Builder
public NeuralSparseSearchTool(
    Client client,
    NamedXContentRegistry xContentRegistry,
    String index,
    String embeddingField,
    String[] sourceFields,
    Integer docSize,
    String modelId
) {
    super(client, xContentRegistry, index, sourceFields, docSize);
    this.embeddingField = embeddingField;
    this.modelId = modelId;
}
|
||
@Override | ||
protected String getQueryBody(String queryText) { | ||
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { | ||
throw new IllegalArgumentException( | ||
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." | ||
); | ||
} | ||
return "{\"query\":{\"neural_sparse\":{\"" | ||
+ embeddingField | ||
+ "\":{\"query_text\":\"" | ||
+ queryText | ||
+ "\",\"model_id\":\"" | ||
+ modelId | ||
+ "\"}}}" | ||
+ " }"; | ||
} | ||
|
||
/** Returns {@link #TYPE}, the type name this tool is annotated with. */
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool> { | ||
private static Factory INSTANCE; | ||
|
||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (NeuralSparseSearchTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
@Override | ||
public NeuralSparseSearchTool create(Map<String, Object> params) { | ||
String index = (String) params.get(INDEX_FIELD); | ||
String embeddingField = (String) params.get(EMBEDDING_FIELD); | ||
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); | ||
String modelId = (String) params.get(MODEL_ID_FIELD); | ||
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE; | ||
return NeuralSparseSearchTool | ||
.builder() | ||
.client(client) | ||
.xContentRegistry(xContentRegistry) | ||
.index(index) | ||
.embeddingField(embeddingField) | ||
.sourceFields(sourceFields) | ||
.modelId(modelId) | ||
.docSize(docSize) | ||
.build(); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what will be the current version in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's 5.5.0 in this case
https://github.com/opensearch-project/OpenSearch/blob/f92f846a1f9b30a055dde846fd12d987a511723a/buildSrc/version.properties#L58