Skip to content

Commit

Permalink
Simplify instantiating Data Object Request/Response builders (opensea…
Browse files Browse the repository at this point in the history
…rch-project#2608)

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored and arjunkumargiri committed Jul 9, 2024
1 parent 8315ee3 commit 15221b8
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 30 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ on:
pull_request_target:
types: [opened, synchronize, reopened]

env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true

permissions:
id-token: write
contents: read
Expand Down Expand Up @@ -40,7 +43,7 @@ jobs:

steps:
- name: Setup Java ${{ matrix.java }}
uses: actions/setup-java@v1
uses: actions/setup-java@v3
with:
java-version: ${{ matrix.java }}

Expand Down Expand Up @@ -145,8 +148,8 @@ jobs:
- name: Generate Password For Admin
id: genpass
run: |
PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-')
echo "password={$PASSWORD}" >> $GITHUB_OUTPUT
PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-')
echo "password={$PASSWORD}" >> $GITHUB_OUTPUT
- name: Run Docker Image
if: env.imagePresent == 'true'
run: |
Expand Down
1 change: 0 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ dependencies {
implementation "software.amazon.awssdk:third-party-jackson-core:2.25.40"
implementation("software.amazon.awssdk:url-connection-client:2.25.40")
implementation("software.amazon.awssdk:utils:2.25.40")
implementation("software.amazon.awssdk:apache-client:2.25.40")


configurations.all {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCrea

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLCreateControllerResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Boolean isHidden = mlModel.getIsHidden();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
Boolean isHidden = mlModel.getIsHidden();
modelAccessControlHelper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCont
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLController controller = MLController.parse(parser);
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
Boolean isHidden = mlModel.getIsHidden();
modelAccessControlHelper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Boolean isHidden = mlModel.getIsHidden();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,10 @@ private void deployModel(
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE && !mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
if (algorithm == FunctionName.REMOTE) {
if (!mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
return;
}
mlTaskManager.add(mlTask, eligibleNodeIds);
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
return;
Expand Down
17 changes: 11 additions & 6 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
Expand Down Expand Up @@ -370,7 +370,8 @@ public void registerMLRemoteModel(
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();

String modelGroupId = mlRegisterModelInput.getModelGroupId();
GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest.builder()
GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest
.builder()
.index(ML_MODEL_GROUP_INDEX)
.tenantId(mlRegisterModelInput.getTenantId())
.id(modelGroupId)
Expand All @@ -395,7 +396,8 @@ public void registerMLRemoteModel(
*/
modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion);
modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest.builder()
UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest
.builder()
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.tenantId(mlRegisterModelInput.getTenantId())
Expand Down Expand Up @@ -589,7 +591,8 @@ private void indexRemoteModel(
.tenantId(registerModelInput.getTenantId())
.build();

PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest.builder()
PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest
.builder()
.index(ML_MODEL_INDEX)
.id(Boolean.TRUE.equals(registerModelInput.getIsHidden()) ? modelName : null)
.tenantId(registerModelInput.getTenantId())
Expand Down Expand Up @@ -1635,7 +1638,8 @@ public void getModel(String modelId, String tenantId, ActionListener<MLModel> li
* @param listener action listener
*/
public void getModel(String modelId, String tenantId, String[] includes, String[] excludes, ActionListener<MLModel> listener) {
GetDataObjectRequest getRequest = GetDataObjectRequest.builder()
GetDataObjectRequest getRequest = GetDataObjectRequest
.builder()
.index(ML_MODEL_INDEX)
.id(modelId)
.tenantId(tenantId)
Expand Down Expand Up @@ -1706,7 +1710,8 @@ public void getController(String modelId, ActionListener<MLController> listener)
* @param listener action listener
*/
private void getConnector(String connectorId, String tenantId, ActionListener<Connector> listener) {
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest.builder()
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest
.builder()
.index(ML_CONNECTOR_INDEX)
.id(connectorId)
.tenantId(tenantId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ public CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRe
if (getResponse == null) {
return GetDataObjectResponse.builder().id(request.id()).parser(null).build();
}
return GetDataObjectResponse.builder()
return GetDataObjectResponse
.builder()
.id(getResponse.getId())
.parser(createParser(getResponse))
.source(getResponse.getSource())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
package org.opensearch.ml.sdkclient;

import org.apache.http.HttpHost;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.opensearch.OpenSearchException;
import org.opensearch.client.Client;
import org.opensearch.client.RestClient;
import org.opensearch.client.json.jackson.JacksonJsonpMapper;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.transport.aws.AwsSdk2Transport;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
import org.opensearch.client.transport.rest_client.RestClientTransport;
import org.opensearch.common.inject.AbstractModule;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.sdk.SdkClient;
Expand All @@ -28,8 +30,6 @@
import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;

Expand Down Expand Up @@ -135,33 +135,24 @@ private OpenSearchClient createOpenSearchClient() {
try {
BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
// Basic http(not-s) client using RestClient.
SdkHttpClient httpClient = ApacheHttpClient.builder().build();
AwsSdk2Transport awsSdk2Transport = new AwsSdk2Transport(
httpClient,
HttpHost.create(remoteMetadataEndpoint).getHostName(),
"aoss",
Region.of(region),
AwsSdk2TransportOptions.builder().build()
);
/*RestClient restClient = RestClient
RestClient restClient = RestClient
// This HttpHost syntax works with export REMOTE_METADATA_ENDPOINT=http://127.0.0.1:9200
.builder(HttpHost.create(remoteMetadataEndpoint))
.setStrictDeprecationMode(true)
.setHttpClientConfigCallback(httpClientBuilder -> {
try {
return httpClientBuilder
.setDefaultCredentialsProvider(credentialsProvider)
.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE);
.setDefaultCredentialsProvider(credentialsProvider)
.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE);
} catch (Exception e) {
throw new OpenSearchException(e);
}
})
.build();*/
.build();
ObjectMapper objectMapper = new ObjectMapper()
.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE)
.setSerializationInclusion(JsonInclude.Include.NON_NULL);
// return new OpenSearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(objectMapper)));
return new OpenSearchClient(awsSdk2Transport);
return new OpenSearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(objectMapper)));
} catch (Exception e) {
throw new OpenSearchException(e);
}
Expand Down

0 comments on commit 15221b8

Please sign in to comment.