-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #40 from zhichao-aws/SearchTools
feature: Add AbstractRetriverTool, VectorDBTool, NeuralSparseTools (cherry picked from commit c088f77)
- Loading branch information
1 parent
c931029
commit 07ffd9c
Showing
10 changed files
with
834 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
140 changes: 140 additions & 0 deletions
140
src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Abstract tool supports search paradigms in neural-search plugin. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
public abstract class AbstractRetrieverTool implements Tool { | ||
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; | ||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String SOURCE_FIELD = "source_field"; | ||
public static final String DOC_SIZE_FIELD = "doc_size"; | ||
public static final int DEFAULT_DOC_SIZE = 2; | ||
|
||
protected String description = DEFAULT_DESCRIPTION; | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
protected String index; | ||
protected String[] sourceFields; | ||
protected Integer docSize; | ||
protected String version; | ||
|
||
protected AbstractRetrieverTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String[] sourceFields, | ||
Integer docSize | ||
) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.index = index; | ||
this.sourceFields = sourceFields; | ||
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize; | ||
} | ||
|
||
protected abstract String getQueryBody(String queryText); | ||
|
||
private <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException { | ||
String question = parameters.get(INPUT_FIELD); | ||
if (StringUtils.isBlank(question)) { | ||
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); | ||
} | ||
|
||
String query = getQueryBody(question); | ||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
searchSourceBuilder.fetchSource(sourceFields, null); | ||
searchSourceBuilder.size(docSize); | ||
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); | ||
return searchRequest; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
SearchRequest searchRequest; | ||
try { | ||
searchRequest = buildSearchRequest(parameters); | ||
} catch (Exception e) { | ||
log.error("Failed to build search request.", e); | ||
listener.onFailure(e); | ||
return; | ||
} | ||
|
||
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (int i = 0; i < hits.length; i++) { | ||
SearchHit hit = hits[i]; | ||
Map<String, Object> docContent = new HashMap<>(); | ||
docContent.put("_index", hit.getIndex()); | ||
docContent.put("_id", hit.getId()); | ||
docContent.put("_score", hit.getScore()); | ||
docContent.put("_source", hit.getSourceAsMap()); | ||
contextBuilder.append(gson.toJson(docContent)).append("\n"); | ||
} | ||
listener.onResponse((T) contextBuilder.toString()); | ||
} else { | ||
listener.onResponse((T) "Can not get any match from search result."); | ||
} | ||
}, e -> { | ||
log.error("Failed to search index.", e); | ||
listener.onFailure(e); | ||
}); | ||
client.search(searchRequest, actionListener); | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
return parameters != null && parameters.size() > 0 && !StringUtils.isBlank(parameters.get("input")); | ||
} | ||
|
||
protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> { | ||
protected Client client; | ||
protected NamedXContentRegistry xContentRegistry; | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
110 changes: 110 additions & 0 deletions
110
src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.gson; | ||
|
||
import java.util.Map; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* This tool supports neural_sparse search with sparse encoding models and rank_features field. | ||
*/ | ||
@Log4j2 | ||
@Getter | ||
@Setter | ||
@ToolAnnotation(NeuralSparseSearchTool.TYPE) | ||
public class NeuralSparseSearchTool extends AbstractRetrieverTool { | ||
public static final String TYPE = "NeuralSparseSearchTool"; | ||
public static final String MODEL_ID_FIELD = "model_id"; | ||
public static final String EMBEDDING_FIELD = "embedding_field"; | ||
|
||
private String name = TYPE; | ||
private String modelId; | ||
private String embeddingField; | ||
|
||
@Builder | ||
public NeuralSparseSearchTool( | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
String index, | ||
String embeddingField, | ||
String[] sourceFields, | ||
Integer docSize, | ||
String modelId | ||
) { | ||
super(client, xContentRegistry, index, sourceFields, docSize); | ||
this.modelId = modelId; | ||
this.embeddingField = embeddingField; | ||
} | ||
|
||
@Override | ||
protected String getQueryBody(String queryText) { | ||
if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { | ||
throw new IllegalArgumentException( | ||
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." | ||
); | ||
} | ||
return "{\"query\":{\"neural_sparse\":{\"" | ||
+ embeddingField | ||
+ "\":{\"query_text\":\"" | ||
+ queryText | ||
+ "\",\"model_id\":\"" | ||
+ modelId | ||
+ "\"}}}" | ||
+ " }"; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool> { | ||
private static Factory INSTANCE; | ||
|
||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (NeuralSparseSearchTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
@Override | ||
public NeuralSparseSearchTool create(Map<String, Object> params) { | ||
String index = (String) params.get(INDEX_FIELD); | ||
String embeddingField = (String) params.get(EMBEDDING_FIELD); | ||
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); | ||
String modelId = (String) params.get(MODEL_ID_FIELD); | ||
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE; | ||
return NeuralSparseSearchTool | ||
.builder() | ||
.client(client) | ||
.xContentRegistry(xContentRegistry) | ||
.index(index) | ||
.embeddingField(embeddingField) | ||
.sourceFields(sourceFields) | ||
.modelId(modelId) | ||
.docSize(docSize) | ||
.build(); | ||
} | ||
} | ||
} |
Oops, something went wrong.