-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Hailong Cui <[email protected]>
- Loading branch information
1 parent
5a9dbcd
commit a92f52f
Showing
5 changed files
with
264 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
150 changes: 150 additions & 0 deletions
150
src/main/java/org/opensearch/agent/tools/PainlessTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
import org.opensearch.script.Script; | ||
import org.opensearch.script.ScriptService; | ||
import org.opensearch.script.ScriptType; | ||
import org.opensearch.script.TemplateScript; | ||
|
||
import com.google.gson.Gson; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* use case for this tool will only focus on flow agent | ||
*/ | ||
@Log4j2 | ||
@Setter | ||
@Getter | ||
@ToolAnnotation(PainlessTool.TYPE) | ||
public class PainlessTool implements Tool { | ||
public static final String TYPE = "PainlessTool"; | ||
private static final String DEFAULT_DESCRIPTION = "Use this tool to execute painless script"; | ||
|
||
@Setter | ||
@Getter | ||
private String name = TYPE; | ||
|
||
@Getter | ||
private String type = TYPE; | ||
|
||
@Getter | ||
@Setter | ||
private String description = DEFAULT_DESCRIPTION; | ||
|
||
@Getter | ||
private String version; | ||
|
||
private ScriptService scriptService; | ||
@Setter | ||
private String scriptCode; | ||
|
||
public PainlessTool(ScriptService scriptEngine, String script) { | ||
this.scriptService = scriptEngine; | ||
this.scriptCode = script; | ||
} | ||
|
||
private Gson gson = new Gson(); | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
Script script = new Script(ScriptType.INLINE, "painless", scriptCode, Collections.emptyMap()); | ||
Map<String, Object> flattenedParameters = new HashMap<>(); | ||
for (Map.Entry<String, String> entry : parameters.entrySet()) { | ||
// keep original values and flatten | ||
flattenedParameters.put(entry.getKey(), entry.getValue()); | ||
// TODO default is json parser. we may support format | ||
try { | ||
String value = org.apache.commons.text.StringEscapeUtils.unescapeJson(entry.getValue()); | ||
Map<String, ?> map = gson.fromJson(value, Map.class); | ||
flattenMap(map, flattenedParameters, entry.getKey()); | ||
} catch (Throwable ignored) {} | ||
} | ||
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(flattenedParameters); | ||
try { | ||
String result = templateScript.execute(); | ||
listener.onResponse(result == null ? (T) "" : (T) result); | ||
} catch (Exception e) { | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
private void flattenMap(Map<String, ?> map, Map<String, Object> flatMap, String prefix) { | ||
for (Map.Entry<String, ?> entry : map.entrySet()) { | ||
String key = entry.getKey(); | ||
if (prefix != null && !prefix.isEmpty()) { | ||
key = prefix + "." + entry.getKey(); | ||
} | ||
Object value = entry.getValue(); | ||
if (value instanceof Map) { | ||
flattenMap((Map<String, ?>) value, flatMap, key); | ||
} else { | ||
flatMap.put(key, value); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> map) { | ||
return true; | ||
} | ||
|
||
public static class Factory implements Tool.Factory<PainlessTool> { | ||
private ScriptService scriptService; | ||
|
||
private static PainlessTool.Factory INSTANCE; | ||
|
||
public static PainlessTool.Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (PainlessTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new PainlessTool.Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
public void init(ScriptService scriptService) { | ||
this.scriptService = scriptService; | ||
} | ||
|
||
@Override | ||
public PainlessTool create(Map<String, Object> map) { | ||
String script = (String) map.get("script"); | ||
// TODO add script non null/empty check | ||
return new PainlessTool(scriptService, script); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
|
||
@Override | ||
public String getDefaultType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getDefaultVersion() { | ||
return null; | ||
} | ||
|
||
} | ||
} |
79 changes: 79 additions & 0 deletions
79
src/test/java/org/opensearch/integTest/PainlessToolIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.integTest; | ||
|
||
import java.io.IOException; | ||
import java.net.URISyntaxException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Before; | ||
|
||
import lombok.SneakyThrows; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class PainlessToolIT extends BaseAgentToolsIT { | ||
|
||
private String registerAgentRequestBody; | ||
|
||
@Before | ||
@SneakyThrows | ||
public void setUp() { | ||
super.setUp(); | ||
registerAgentRequestBody = Files | ||
.readString( | ||
Path.of(this.getClass().getClassLoader().getResource("org/opensearch/agent/tools/register_painless_agent.json").toURI()) | ||
); | ||
} | ||
|
||
public void test_execute() { | ||
String script = "def x = new HashMap(); x.abc = '5'; return x.abc;"; | ||
String agentRequestBody = registerAgentRequestBody.replaceAll("<SCRIPT>", script); | ||
String agentId = createAgent(agentRequestBody); | ||
String agentInput = "{\"parameters\":{}}"; | ||
String result = executeAgent(agentId, agentInput); | ||
Assert.assertEquals("5", result); | ||
} | ||
|
||
public void test_execute_with_parameter() { | ||
String script = "params.x + params.y"; | ||
String agentRequestBody = registerAgentRequestBody.replaceAll("<SCRIPT>", script); | ||
String agentId = createAgent(agentRequestBody); | ||
String agentInput = "{\"parameters\":{\"x\":1,\"y\":2}}"; | ||
String result = executeAgent(agentId, agentInput); | ||
Assert.assertEquals("12", result); | ||
} | ||
|
||
public void test_execute_with_parameter2() throws URISyntaxException, IOException { | ||
String script = | ||
"return 'An example output: with ppl:<ppl>' + params.get('PPL.output.ppl') + '</ppl>, and this is ppl result: <ppl_result>' + params.get('PPL.output.executionResult') + '</ppl_result>'"; | ||
String mockPPLOutput = "return '{\\\\\"executionResult\\\\\":\\\\\"result\\\\\",\\\\\"ppl\\\\\":\\\\\"source=demo| head 1\\\\\"}'"; | ||
String registerAgentRequestBody2 = Files | ||
.readString( | ||
Path | ||
.of( | ||
this | ||
.getClass() | ||
.getClassLoader() | ||
.getResource("org/opensearch/agent/tools/register_painless_agent_with_multiple_tools.json") | ||
.toURI() | ||
) | ||
); | ||
String agentRequestBody = registerAgentRequestBody2.replaceAll("<SCRIPT1>", mockPPLOutput).replaceAll("<SCRIPT2>", script); | ||
|
||
log.info("agentRequestBody = {}", agentRequestBody); | ||
String agentId = createAgent(agentRequestBody); | ||
String agentInput = "{\"parameters\":{}}"; | ||
String result = executeAgent(agentId, agentInput); | ||
Assert | ||
.assertEquals( | ||
"An example output: with ppl:<ppl>source=demo| head 1</ppl>, and this is ppl result: <ppl_result>result</ppl_result>", | ||
result | ||
); | ||
} | ||
} |
12 changes: 12 additions & 0 deletions
12
src/test/resources/org/opensearch/agent/tools/register_painless_agent.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"name": "Test_PainlessTool", | ||
"type": "flow", | ||
"tools": [ | ||
{ | ||
"type": "PainlessTool", | ||
"parameters": { | ||
"script": "<SCRIPT>" | ||
} | ||
} | ||
] | ||
} |
19 changes: 19 additions & 0 deletions
19
...est/resources/org/opensearch/agent/tools/register_painless_agent_with_multiple_tools.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"name": "Test_PainlessTool", | ||
"type": "flow", | ||
"tools": [ | ||
{ | ||
"type": "PainlessTool", | ||
"name": "PPL", | ||
"parameters": { | ||
"script": "<SCRIPT1>" | ||
} | ||
}, | ||
{ | ||
"type": "PainlessTool", | ||
"parameters": { | ||
"script": "<SCRIPT2>" | ||
} | ||
} | ||
] | ||
} |