From 7d48a727bda5684625837eef12b1197042dd7d2e Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 30 Oct 2023 18:07:35 -0700 Subject: [PATCH 1/4] Fix Create Connector actions parsing Signed-off-by: Daniel Widdis --- .../workflow/CreateConnectorStep.java | 67 ++++++++++++++----- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index e17bf2aa0..7cf3d2898 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -16,6 +16,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -24,8 +25,10 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.CompletableFuture; @@ -84,38 +87,35 @@ public void onFailure(Exception e) { String description = null; String version = null; String protocol = null; - Map parameters = new HashMap<>(); - Map credentials = new HashMap<>(); - List actions = new ArrayList<>(); + Map parameters = Collections.emptyMap(); + Map credentials = Collections.emptyMap(); + List actions = Collections.emptyList(); for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { + for (Entry entry : workflowData.getContent().entrySet()) { switch (entry.getKey()) { case NAME_FIELD: - name = (String) content.get(NAME_FIELD); + name = (String) entry.getValue(); break; case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); + description = (String) entry.getValue(); break; case VERSION_FIELD: - version = (String) content.get(VERSION_FIELD); + version = (String) entry.getValue(); break; case PROTOCOL_FIELD: - protocol = (String) content.get(PROTOCOL_FIELD); + protocol = (String) entry.getValue(); break; case PARAMETERS_FIELD: - parameters = getParameterMap((Map) content.get(PARAMETERS_FIELD)); + parameters = getParameterMap(entry.getValue()); break; case CREDENTIALS_FIELD: - credentials = (Map) content.get(CREDENTIALS_FIELD); + credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); break; case ACTIONS_FIELD: - actions = (List) content.get(ACTIONS_FIELD); + actions = getConnectorActionList(entry.getValue()); break; } - } } @@ -145,14 +145,20 @@ public String getName() { return NAME; } - private static Map getParameterMap(Map params) { + @SuppressWarnings("unchecked") + private static Map getStringToStringMap(Object map, String fieldName) { + if (map instanceof Map) { + return (Map) map; + } + throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); + } + private static Map getParameterMap(Object parameterMap) { Map parameters = new HashMap<>(); - for (String key : params.keySet()) { - String value = params.get(key); + for (Entry entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) { try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - parameters.put(key, value); + parameters.put(entry.getKey(), entry.getValue()); return null; }); } catch (PrivilegedActionException e) { @@ -162,4 +168,29 @@ private static Map getParameterMap(Map params) { return parameters; } + private static List getConnectorActionList(Object array) { + if (!(array instanceof Map[])) { + throw new IllegalArgumentException("[" + ACTIONS_FIELD + "] must be an array of key-value maps."); + } + List actions = new ArrayList<>(); + for (Map map : (Map[]) array) { + String actionType = (String) map.get(ConnectorAction.ACTION_TYPE_FIELD); + if (actionType == null) { + throw new IllegalArgumentException("[" + ConnectorAction.ACTION_TYPE_FIELD + "] is missing."); + } + @SuppressWarnings("unchecked") + ConnectorAction action = ConnectorAction.builder() + .actionType(ActionType.valueOf(actionType.toUpperCase(Locale.ROOT))) + .method((String) map.get(ConnectorAction.METHOD_FIELD)) + .url((String) map.get(ConnectorAction.URL_FIELD)) + .headers((Map) map.get(ConnectorAction.HEADERS_FIELD)) + .requestBody((String) map.get(ConnectorAction.REQUEST_BODY_FIELD)) + .preProcessFunction((String) map.get(ConnectorAction.ACTION_PRE_PROCESS_FUNCTION)) + .postProcessFunction((String) map.get(ConnectorAction.ACTION_POST_PROCESS_FUNCTION)) + .build(); + actions.add(action); + } + return actions; + } + } From 639a80a143f5258fd7cd3c212954bfe25bdcdad3 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 30 Oct 2023 18:38:56 -0700 Subject: [PATCH 2/4] Fix tests Signed-off-by: Daniel Widdis --- .../workflow/CreateConnectorStepTests.java | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index b54b2a27c..cee480bf6 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -27,7 +27,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -44,13 +43,19 @@ public void setUp() throws Exception { Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); + Map[] actions = new Map[] { + Map.ofEntries( + Map.entry(ConnectorAction.ACTION_TYPE_FIELD, ConnectorAction.ActionType.PREDICT.name()), + Map.entry(ConnectorAction.METHOD_FIELD, "post"), + Map.entry(ConnectorAction.URL_FIELD, "foo.test"), + Map.entry( + ConnectorAction.REQUEST_BODY_FIELD, + "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }" + ) + ) }; MockitoAnnotations.openMocks(this); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "post"; - String url = "foot.test"; - inputData = new WorkflowData( Map.ofEntries( Map.entry("name", "test"), @@ -59,23 +64,9 @@ public void setUp() throws Exception { Map.entry("protocol", "test"), Map.entry("params", params), Map.entry("credentials", credentials), - Map.entry( - "actions", - List.of( - new ConnectorAction( - actionType, - method, - url, - null, - "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }", - null, - null - ) - ) - ) + Map.entry("actions", actions) ) ); - } public void testCreateConnector() throws IOException, ExecutionException, InterruptedException { @@ -83,6 +74,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr String connectorId = "connect"; CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { @@ -104,6 +96,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr public void testCreateConnectorFailure() throws IOException { CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { From 4a3d5089af014b9fbf9a114cbf8803372987f85b Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 30 Oct 2023 19:24:04 -0700 Subject: [PATCH 3/4] Handle exceptions Signed-off-by: Daniel Widdis --- .../workflow/CreateConnectorStep.java | 70 ++++++++++--------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 7cf3d2898..533d82c1e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -91,32 +91,40 @@ public void onFailure(Exception e) { Map credentials = Collections.emptyMap(); List actions = Collections.emptyList(); - for (WorkflowData workflowData : data) { - for (Entry entry : workflowData.getContent().entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - name = (String) entry.getValue(); - break; - case DESCRIPTION_FIELD: - description = (String) entry.getValue(); - break; - case VERSION_FIELD: - version = (String) entry.getValue(); - break; - case PROTOCOL_FIELD: - protocol = (String) entry.getValue(); - break; - case PARAMETERS_FIELD: - parameters = getParameterMap(entry.getValue()); - break; - case CREDENTIALS_FIELD: - credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); - break; - case ACTIONS_FIELD: - actions = getConnectorActionList(entry.getValue()); - break; + try { + for (WorkflowData workflowData : data) { + for (Entry entry : workflowData.getContent().entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + name = (String) entry.getValue(); + break; + case DESCRIPTION_FIELD: + description = (String) entry.getValue(); + break; + case VERSION_FIELD: + version = (String) entry.getValue(); + break; + case PROTOCOL_FIELD: + protocol = (String) entry.getValue(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(entry.getValue()); + break; + case CREDENTIALS_FIELD: + credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); + break; + case ACTIONS_FIELD: + actions = getConnectorActionList(entry.getValue()); + break; + } } } + } catch (IllegalArgumentException iae) { + createConnectorFuture.completeExceptionally(new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST)); + return createConnectorFuture; + } catch (PrivilegedActionException pae) { + createConnectorFuture.completeExceptionally(new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED)); + return createConnectorFuture; } if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) { @@ -153,17 +161,13 @@ private static Map getStringToStringMap(Object map, String field throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); } - private static Map getParameterMap(Object parameterMap) { + private static Map getParameterMap(Object parameterMap) throws PrivilegedActionException { Map parameters = new HashMap<>(); for (Entry entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) { - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - parameters.put(entry.getKey(), entry.getValue()); - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + parameters.put(entry.getKey(), entry.getValue()); + return null; + }); } return parameters; } From 912a856d366d352b2efd329e9279214ebf2880af Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 30 Oct 2023 20:25:23 -0700 Subject: [PATCH 4/4] Fix parameter key in tests Signed-off-by: Daniel Widdis --- .../workflow/CreateConnectorStepTests.java | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index cee480bf6..63855f7bd 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -10,6 +10,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; @@ -58,13 +59,13 @@ public void setUp() throws Exception { inputData = new WorkflowData( Map.ofEntries( - Map.entry("name", "test"), - Map.entry("description", "description"), - Map.entry("version", "1"), - Map.entry("protocol", "test"), - Map.entry("params", params), - Map.entry("credentials", credentials), - Map.entry("actions", actions) + Map.entry(CommonValue.NAME_FIELD, "test"), + Map.entry(CommonValue.DESCRIPTION_FIELD, "description"), + Map.entry(CommonValue.VERSION_FIELD, "1"), + Map.entry(CommonValue.PROTOCOL_FIELD, "test"), + Map.entry(CommonValue.PARAMETERS_FIELD, params), + Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), + Map.entry(CommonValue.ACTIONS_FIELD, actions) ) ); }