-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: Add AbstractRetriverTool, VectorDBTool, NeuralSparseTools #40
Changes from 4 commits
7d1fb29
355bbe7
340a6ed
1114732
7fe33f6
fd3cf33
4311116
02f0d45
43a228f
97f92e9
de436a4
6ec9b0c
ef03c34
b1dbe84
b3a964b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,7 +82,7 @@ configurations { | |
zipArchive | ||
all { | ||
resolutionStrategy { | ||
force "org.mockito:mockito-core:5.8.0" | ||
force "org.mockito:mockito-core:${versions.mockito}" | ||
force "com.google.guava:guava:32.1.3-jre" // CVE for 31.1 | ||
force("org.eclipse.platform:org.eclipse.core.runtime:3.30.0") // CVE for < 3.29.0, forces JDK17 for spotless | ||
} | ||
|
@@ -103,12 +103,15 @@ dependencies { | |
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' | ||
compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.22.0" | ||
compileOnly group: 'org.json', name: 'json', version: '20231013' | ||
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" | ||
compileOnly "org.opensearch:common-utils:${version}" | ||
|
||
implementation("com.google.guava:guava:32.1.3-jre") | ||
implementation group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}" | ||
implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"]) | ||
compileOnly "org.opensearch:common-utils:${version}" | ||
|
||
testImplementation "org.opensearch.test:framework:${opensearch_version}" | ||
testImplementation "org.mockito:mockito-core:5.8.0" | ||
testImplementation "net.bytebuddy:byte-buddy-agent:${versions.bytebuddy}" | ||
testImplementation "org.mockito:mockito-core:${versions.mockito}" | ||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1' | ||
testImplementation 'org.mockito:mockito-junit-jupiter:5.8.0' | ||
testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0" | ||
|
@@ -118,6 +121,7 @@ dependencies { | |
|
||
// ZipArchive dependencies used for integration tests | ||
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" | ||
zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" | ||
zhichao-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
task extractSqlJar(type: Copy) { | ||
|
@@ -137,12 +141,13 @@ testingConventions.enabled = false | |
thirdPartyAudit.enabled = false | ||
|
||
test { | ||
useJUnitPlatform() | ||
testLogging { | ||
exceptionFormat "full" | ||
events "skipped", "passed", "failed" // "started" | ||
showStandardStreams true | ||
} | ||
include '**/*Tests.class' | ||
systemProperty 'tests.security.manager', 'false' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a comment what is this for? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It turns off the security manager to make it easy for writing test case. It is used in many build script of Opensearch plugins. However we don't have such test cases currently. So I just removed this line and we can add it back when we need it. |
||
} | ||
|
||
spotless { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Abstract tool supports search paradigms in neural-search plugin. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
public abstract class AbstractRetrieverTool implements Tool { | ||
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; | ||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String SOURCE_FIELD = "source_field"; | ||
public static final String DOC_SIZE_FIELD = "doc_size"; | ||
public static final int DEFAULT_DOC_SIZE = 2; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this to control the matching docs size? I remember it's 10 by default in search API. Are we giving 2 randomly or there's a reason behind this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This number is inherited from origin VectorDBTool. I guess the reason maybe we want to limit the context size for LLM. I'm open for changing this number There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's confirm and add a comment for this field. I'm also curious to know more about this field. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The origin VectorDBTool was introduced from the agent framework POC commit. opensearch-project/ml-commons@1b85cff#diff-b353ce1eda5b942809ea500dadcfda5769504edd8fa17fc174d498937097c69eR67 Hi @ylwu-amzn, do you know why the default doc size is 2? Do we set this for context length? By the way I don't think this is a blocker issue for this PR. This parameter is configurable, and we can alter its value after our e2e test. |
||
|
||
protected String description = DEFAULT_DESCRIPTION; | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
protected String index; | ||
protected String[] sourceFields; | ||
protected Integer docSize; | ||
protected String version; | ||
|
||
protected AbstractRetrieverTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String[] sourceFields, | ||
Integer docSize | ||
) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.index = index; | ||
this.sourceFields = sourceFields; | ||
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize; | ||
} | ||
|
||
protected abstract String getQueryBody(String queryText); | ||
|
||
private <T> SearchRequest buildSearchRequest(Map<String, String> parameters, ActionListener<T> listener) throws IOException { | ||
zhichao-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
String question = parameters.get(INPUT_FIELD); | ||
if (StringUtils.isBlank(question)) { | ||
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); | ||
} | ||
|
||
String query = getQueryBody(question); | ||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
searchSourceBuilder.fetchSource(sourceFields, null); | ||
searchSourceBuilder.size(docSize); | ||
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); | ||
return searchRequest; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
SearchRequest searchRequest; | ||
try { | ||
searchRequest = buildSearchRequest(parameters, listener); | ||
} catch (Exception e) { | ||
listener.onFailure(e); | ||
zhichao-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return; | ||
} | ||
|
||
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (int i = 0; i < hits.length; i++) { | ||
SearchHit hit = hits[i]; | ||
Map<String, Object> docContent = new HashMap<>(); | ||
docContent.put("_index", hit.getIndex()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any specific reason to extract these fields instead of deserialize the whole searchResponse as result? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://github.com/opensearch-project/OpenSearch/blob/2c8ee1947b55a1cc5bc1a114b82e3a3b8a99851e/server/src/main/java/org/opensearch/search/SearchHit.java#L616 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The result looks like in jsonl format, is this as expected? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this result will be treated as a norm string and be sent to LLM. |
||
docContent.put("_id", hit.getId()); | ||
docContent.put("_score", hit.getScore()); | ||
docContent.put("_source", hit.getSourceAsMap()); | ||
contextBuilder.append(gson.toJson(docContent)).append("\n"); | ||
} | ||
listener.onResponse((T) contextBuilder.toString()); | ||
} else { | ||
listener.onResponse((T) ""); | ||
zhichao-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
}, e -> { | ||
log.error("Failed to search index.", e); | ||
listener.onFailure(e); | ||
}); | ||
client.search(searchRequest, actionListener); | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
if (parameters == null || parameters.size() == 0) { | ||
return false; | ||
} | ||
String question = parameters.get("input"); | ||
return !StringUtils.isBlank(question); | ||
zhichao-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> { | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* This tool supports neural_sparse search with sparse encoding models and rank_features field. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
@ToolAnnotation(NeuralSparseSearchTool.TYPE) | ||
public class NeuralSparseSearchTool extends AbstractRetrieverTool { | ||
public static final String TYPE = "NeuralSparseSearchTool"; | ||
public static final String MODEL_ID_FIELD = "model_id"; | ||
public static final String EMBEDDING_FIELD = "embedding_field"; | ||
|
||
private String name = TYPE; | ||
private String modelId; | ||
private String embeddingField; | ||
|
||
@Builder | ||
public NeuralSparseSearchTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String embeddingField, | ||
String[] sourceFields, | ||
Integer docSize, | ||
String modelId | ||
) { | ||
super(client, xContentRegistry, index, sourceFields, docSize); | ||
this.modelId = modelId; | ||
this.embeddingField = embeddingField; | ||
} | ||
|
||
@Override | ||
protected String getQueryBody(String queryText) { | ||
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { | ||
throw new IllegalArgumentException( | ||
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." | ||
); | ||
} | ||
return "{\"query\":{\"neural_sparse\":{\"" | ||
+ embeddingField | ||
+ "\":{\"query_text\":\"" | ||
+ queryText | ||
+ "\",\"model_id\":\"" | ||
+ modelId | ||
+ "\"}}}" | ||
+ " }"; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool> { | ||
private static Factory INSTANCE; | ||
|
||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (NeuralSparseSearchTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
@Override | ||
public NeuralSparseSearchTool create(Map<String, Object> params) { | ||
String index = (String) params.get(INDEX_FIELD); | ||
String embeddingField = (String) params.get(EMBEDDING_FIELD); | ||
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); | ||
String modelId = (String) params.get(MODEL_ID_FIELD); | ||
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE; | ||
return NeuralSparseSearchTool | ||
.builder() | ||
.client(client) | ||
.xContentRegistry(xContentRegistry) | ||
.index(index) | ||
.embeddingField(embeddingField) | ||
.sourceFields(sourceFields) | ||
.modelId(modelId) | ||
.docSize(docSize) | ||
.build(); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what will be the current version in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's 5.5.0 in this case
https://github.com/opensearch-project/OpenSearch/blob/f92f846a1f9b30a055dde846fd12d987a511723a/buildSrc/version.properties#L58