Skip to content

Commit

Permalink
add multi tenacy support
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Jan 24, 2025
1 parent 76f1cfc commit 5144fdf
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
Expand Down Expand Up @@ -165,6 +166,7 @@ public CreateAnomalyDetectorTool(Client client, String modelId, String modelType
*/
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String tenantId = parameters.get(TENANT_ID_FIELD);
Map<String, String> enrichedParameters = enrichParameters(parameters);
String indexName = enrichedParameters.get("index");
if (Strings.isNullOrEmpty(indexName)) {
Expand Down Expand Up @@ -228,7 +230,8 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(),
null
null,
tenantId
);

client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> {
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -172,6 +174,7 @@ public PPLTool(
@SuppressWarnings("unchecked")
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String tenantId = parameters.get(TENANT_ID_FIELD);
extractFromChatParameters(parameters);
String indexName = getIndexNameFromParameters(parameters);
if (StringUtils.isBlank(indexName)) {
Expand Down Expand Up @@ -209,7 +212,8 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(),
null
null,
tenantId
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.apache.commons.lang3.StringEscapeUtils.escapeJson;
import static org.opensearch.agent.tools.AbstractRetrieverTool.*;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

Expand Down Expand Up @@ -94,6 +95,8 @@ public Object parse(Object o) {
}

public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String tenantId = parameters.get(TENANT_ID_FIELD);

String input = null;

if (!this.validate(parameters)) {
Expand Down Expand Up @@ -145,7 +148,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null);
ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null, tenantId);

client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput();
Expand Down

0 comments on commit 5144fdf

Please sign in to comment.