Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/feature/multi_tenancy' into up…
Browse files Browse the repository at this point in the history
…date_ddb

Signed-off-by: Arjun kumar Giri <[email protected]>
  • Loading branch information
arjunkumargiri committed Jul 23, 2024
2 parents 034b168 + c0b2e2e commit cc26cd4
Show file tree
Hide file tree
Showing 83 changed files with 1,609 additions and 839 deletions.
22 changes: 18 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/MLConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.time.Instant;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID;

@Getter
@EqualsAndHashCode
Expand All @@ -37,19 +38,22 @@ public class MLConfig implements ToXContentObject, Writeable {

private Configuration configuration;
private final Instant createTime;
private Instant lastUpdateTime;
private final Instant lastUpdateTime;
private final String tenantId;

@Builder(toBuilder = true)
public MLConfig(
String type,
Configuration configuration,
Instant createTime,
Instant lastUpdateTime
Instant lastUpdateTime,
String tenantId
) {
this.type = type;
this.configuration = configuration;
this.createTime = createTime;
this.lastUpdateTime = lastUpdateTime;
this.tenantId = tenantId;
}

public MLConfig(StreamInput input) throws IOException {
Expand All @@ -59,6 +63,8 @@ public MLConfig(StreamInput input) throws IOException {
}
createTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
//TODO: Check BWC later
tenantId = input.readOptionalString();
}

@Override
Expand All @@ -72,6 +78,8 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeOptionalInstant(createTime);
out.writeOptionalInstant(lastUpdateTime);
//TODO: check BWC later
out.writeOptionalString(tenantId);
}

@Override
Expand All @@ -89,19 +97,22 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (lastUpdateTime != null) {
builder.field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli());
}
if (tenantId != null) {
builder.field(TENANT_ID, tenantId);
}
return builder.endObject();
}

public static MLConfig fromStream(StreamInput in) throws IOException {
MLConfig mlConfig = new MLConfig(in);
return mlConfig;
return new MLConfig(in);
}

public static MLConfig parse(XContentParser parser) throws IOException {
String type = null;
Configuration configuration = null;
Instant createTime = null;
Instant lastUpdateTime = null;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -121,6 +132,8 @@ public static MLConfig parse(XContentParser parser) throws IOException {
case LAST_UPDATE_TIME_FIELD:
lastUpdateTime = Instant.ofEpochMilli(parser.longValue());
break;
case TENANT_ID:
tenantId = parser.textOrNull();
default:
parser.skipChildren();
break;
Expand All @@ -131,6 +144,7 @@ public static MLConfig parse(XContentParser parser) throws IOException {
.configuration(configuration)
.createTime(createTime)
.lastUpdateTime(lastUpdateTime)
.tenantId(tenantId)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
Expand All @@ -18,6 +19,7 @@
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

@EqualsAndHashCode
Expand All @@ -34,14 +36,17 @@ public class MLToolSpec implements ToXContentObject {
private String description;
private Map<String, String> parameters;
private boolean includeOutputInAgentResponse;
@Setter
private String tenantId;


@Builder(toBuilder = true)
public MLToolSpec(String type,
String name,
String description,
Map<String, String> parameters,
boolean includeOutputInAgentResponse) {
boolean includeOutputInAgentResponse,
String tenantId) {
if (type == null) {
throw new IllegalArgumentException("tool type is null");
}
Expand All @@ -50,6 +55,7 @@ public MLToolSpec(String type,
this.description = description;
this.parameters = parameters;
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
this.tenantId = tenantId;
}

public MLToolSpec(StreamInput input) throws IOException{
Expand All @@ -60,19 +66,23 @@ public MLToolSpec(StreamInput input) throws IOException{
parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
}
includeOutputInAgentResponse = input.readBoolean();
//TODO: add bwc later
tenantId = input.readOptionalString();
}

public void writeTo(StreamOutput out) throws IOException {
out.writeString(type);
out.writeOptionalString(name);
out.writeOptionalString(description);
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
out.writeBoolean(true);
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
} else {
out.writeBoolean(false);
}
out.writeBoolean(includeOutputInAgentResponse);
//TODO: add BWC later
out.writeOptionalString(tenantId);
}

@Override
Expand All @@ -87,10 +97,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
builder.field(PARAMETERS_FIELD, parameters);
}
builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse);
if (tenantId != null) {
builder.field(TENANT_ID, tenantId);
}
builder.endObject();
return builder;
}
Expand All @@ -101,6 +114,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
String description = null;
Map<String, String> parameters = null;
boolean includeOutputInAgentResponse = false;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -123,6 +137,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
includeOutputInAgentResponse = parser.booleanValue();
break;
case TENANT_ID:
tenantId = parser.textOrNull();
break;
default:
parser.skipChildren();
break;
Expand All @@ -134,6 +151,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
.description(description)
.parameters(parameters)
.includeOutputInAgentResponse(includeOutputInAgentResponse)
.tenantId(tenantId)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ protected Map<String, String> createPredictDecryptedHeaders(Map<String, String>
for (String key : headers.keySet()) {
decryptedHeaders.put(key, substitutor.replace(headers.get(key)));
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
for (String key : decryptedHeaders.keySet()) {
decryptedHeaders.put(key, substitutor.replace(decryptedHeaders.get(key)));
Expand Down Expand Up @@ -135,11 +135,11 @@ public void removeCredential() {
@Override
public String getPredictEndpoint(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (!predictAction.isPresent()) {
if (predictAction.isEmpty()) {
return null;
}
String predictEndpoint = predictAction.get().getUrl();
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;

import java.io.IOException;
import java.util.List;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -66,8 +67,8 @@ public interface Connector extends ToXContentObject, Writeable {

<T> T createPredictPayload(Map<String, String> parameters);

void decrypt(Function<String, String> function);
void encrypt(Function<String, String> function);
void decrypt(BiFunction<String, String, String> function, String tenantId);
void encrypt(BiFunction<String, String, String> function, String tenantId);

Connector cloneConnector();

Expand All @@ -77,7 +78,7 @@ public interface Connector extends ToXContentObject, Writeable {

void writeTo(StreamOutput out) throws IOException;

void update(MLCreateConnectorInput updateContent, Function<String, String> function);
void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function);

<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -281,7 +282,7 @@ public void writeTo(StreamOutput out) throws IOException {
}

@Override
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
Expand All @@ -299,7 +300,7 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
}
if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) {
this.credential = updateContent.getCredential();
encrypt(function);
encrypt(function, this.tenantId);
}
if (updateContent.getActions() != null) {
this.actions = updateContent.getActions();
Expand Down Expand Up @@ -357,10 +358,10 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
}

@Override
public void decrypt(Function<String, String> function) {
public void decrypt(BiFunction<String, String, String> function, String tenantId) {
Map<String, String> decrypted = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key)));
decrypted.put(key, function.apply(credential.get(key), tenantId));
}
this.decryptedCredential = decrypted;
Optional<ConnectorAction> predictAction = findPredictAction();
Expand All @@ -380,9 +381,9 @@ public Connector cloneConnector() {
}

@Override
public void encrypt(Function<String, String> function) {
public void encrypt(BiFunction<String, String, String> function, String tenantId) {
for (String key : credential.keySet()) {
String encrypted = function.apply(credential.get(key));
String encrypted = function.apply(credential.get(key), tenantId);
credential.put(key, encrypted);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;

@Setter
@Getter
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {

@Setter
private Map<String, String> parameters;

@Builder(toBuilder = true)
Expand All @@ -32,7 +32,7 @@ public RemoteInferenceInputDataSet(Map<String, String> parameters) {
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
parameters = streamInput.readMap(StreamInput::readString, StreamInput::readString);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset;
List<String> docs = textInputDataSet.getDocs();
ModelResultFilter resultFilter = textInputDataSet.getResultFilter();
if (docs != null && docs.size() > 0) {
if (docs != null && !docs.isEmpty()) {
builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0]));
}
if (resultFilter != null) {
builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
List<String> targetResponse = resultFilter.getTargetResponse();
if (targetResponse != null && targetResponse.size() > 0) {
if (targetResponse != null && !targetResponse.isEmpty()) {
builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0]));
}
List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
if (targetPositions != null && targetPositions.size() > 0) {
if (targetPositions != null && !targetPositions.isEmpty()) {
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
Expand Down
Loading

0 comments on commit cc26cd4

Please sign in to comment.