Skip to content

Commit

Permalink
Addressed comments
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <[email protected]>
  • Loading branch information
arjunkumargiri committed Jul 10, 2024
1 parent 0c5df38 commit 9644144
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public Builder id(String id) {
return this;
}

/**
/**
* Add a tenant ID to this builder
* @param tenantId the tenant id
* @return the updated builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ private void deployModel(
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE) {
if (!mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
return;
}
mlTaskManager.add(mlTask, eligibleNodeIds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ public MachineLearningPlugin(Settings settings) {

@Override
public Collection<Module> createGuiceModules() {
// TODO: SDKClientModule is initialized both in createGuiceModules and createComponents. Unify these
// approaches to prevent multiple instances of SDKClient.
return List.of(new SdkClientModule(null, null));
}

Expand Down Expand Up @@ -464,6 +466,7 @@ public Collection<Object> createComponents(
Settings settings = environment.settings();
Path dataPath = environment.dataFiles()[0];
Path configFile = environment.configFile();
// TODO: Rather than recreating SDKClientModule reuse module created as part of createGuiceModules
ModulesBuilder modules = new ModulesBuilder();
modules.add(new SdkClientModule(client, xContentRegistry));
Injector injector = modules.createInjector();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package org.opensearch.ml.action.controller;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,32 @@ public void testDoExecuteRemoteInferenceDisabled() {
assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage());
}

public void testDoExecuteRemoteInference_MultiNodeEnabled() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE);
when(mlModel.getTenantId()).thenReturn("test_tenant");
doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(4);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class));
doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
IndexResponse indexResponse = mock(IndexResponse.class);
when(indexResponse.getId()).thenReturn("mockIndexId");
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class));

when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true);
ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class);
when(mlDeployModelRequest.getTenantId()).thenReturn("test_tenant");
transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener);
ArgumentCaptor<MLDeployModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class);
verify(deployModelResponseListener).onResponse(argumentCaptor.capture());
assertEquals("CREATED", argumentCaptor.getValue().getStatus());
}

public void testDoExecuteLocalInferenceDisabled() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
Expand Down

0 comments on commit 9644144

Please sign in to comment.