Skip to content

Commit

Permalink
Set tenant ID while creating connector and model group
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <[email protected]>
  • Loading branch information
arjunkumargiri committed Jul 2, 2024
1 parent 1eb1460 commit 0c5c7fd
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLConn
GetDataObjectRequest getDataObjectRequest = new GetDataObjectRequest.Builder()
.index(ML_CONNECTOR_INDEX)
.id(connectorId)
.tenantId(tenantId)
.fetchSourceContext(fetchSourceContext)
.build();
User user = RestActionUtils.getUserContext(client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@ private void indexConnector(Connector connector, ActionListener<MLCreateConnecto
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
sdkClient
.putDataObjectAsync(
new PutDataObjectRequest.Builder().index(ML_CONNECTOR_INDEX).dataObject(connector).build(),
new PutDataObjectRequest.Builder()
.tenantId(connector.getTenantId())
.index(ML_CONNECTOR_INDEX)
.dataObject(connector)
.build(),
client.threadPool().executor(GENERAL_THREAD_POOL)
)
.whenComplete((r, throwable) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
mlIndicesHandler.initModelGroupIndexIfAbsent(ActionListener.wrap(res -> {
sdkClient
.putDataObjectAsync(
new PutDataObjectRequest.Builder().index(ML_MODEL_GROUP_INDEX).dataObject(mlModelGroup).build(),
new PutDataObjectRequest.Builder()
.tenantId(mlModelGroup.getTenantId())
.index(ML_MODEL_GROUP_INDEX)
.dataObject(mlModelGroup)
.build(),
client.threadPool().executor(GENERAL_THREAD_POOL)
)
.whenComplete((r, throwable) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.get.GetResponse;
Expand Down Expand Up @@ -249,7 +252,14 @@ public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDat
*/
@Override
public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(SearchDataObjectRequest request, Executor executor) {
return this.remoteClusterIndicesClient.searchDataObjectAsync(request, executor);
List<String> indices = Arrays.stream(request.indices()).map(this::getTableName).collect(Collectors.toList());

SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest(
indices.toArray(new String[0]),
request.tenantId(),
request.searchSourceBuilder()
);
return this.remoteClusterIndicesClient.searchDataObjectAsync(searchDataObjectRequest, executor);
}

private String getTableName(String index) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import java.util.concurrent.TimeUnit;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
Expand Down Expand Up @@ -50,6 +52,7 @@
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.sdkclient.LocalClusterIndicesClient;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.sdk.GetDataObjectRequest;
import org.opensearch.sdk.SdkClient;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ScalingExecutorBuilder;
Expand All @@ -60,6 +63,8 @@
public class GetConnectorTransportActionTests extends OpenSearchTestCase {
private static final String CONNECTOR_ID = "connector_id";

private static final String TENANT_ID = "tenant_id";

private static final TestThreadPool testThreadPool = new TestThreadPool(
GetConnectorTransportActionTests.class.getName(),
new ScalingExecutorBuilder(
Expand Down Expand Up @@ -104,12 +109,15 @@ public class GetConnectorTransportActionTests extends OpenSearchTestCase {
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Captor
private ArgumentCaptor<GetDataObjectRequest> getDataObjectRequestArgumentCaptor;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);

sdkClient = new LocalClusterIndicesClient(client, xContentRegistry);
mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).build();
mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).tenantId(TENANT_ID).build();
when(getResponse.getId()).thenReturn(CONNECTOR_ID);
when(getResponse.getSourceAsString()).thenReturn("{}");
when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false);
Expand Down Expand Up @@ -198,13 +206,16 @@ public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, I
ActionListener<Connector> listener = invocation.getArgument(5);
listener.onResponse(httpConnector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any());
}).when(connectorAccessControlHelper).getConnector(any(), any(), any(),
getDataObjectRequestArgumentCaptor.capture(), any(), any());

CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<MLConnectorGetResponse> latchedActionListener = new LatchedActionListener<>(actionListener, latch);
getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener);
latch.await(500, TimeUnit.MILLISECONDS);

Assert.assertEquals(tenantId, getDataObjectRequestArgumentCaptor.getValue().tenantId());
Assert.assertEquals(CONNECTOR_ID, getDataObjectRequestArgumentCaptor.getValue().id());
ArgumentCaptor<MLConnectorGetResponse> argumentCaptor = ArgumentCaptor.forClass(MLConnectorGetResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
assertEquals(tenantId, argumentCaptor.getValue().getMlConnector().getTenantId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@
import java.util.concurrent.TimeUnit;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.LatchedActionListener;
Expand Down Expand Up @@ -60,6 +63,7 @@
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.sdkclient.LocalClusterIndicesClient;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.sdk.PutDataObjectRequest;
import org.opensearch.sdk.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -74,6 +78,7 @@
public class TransportCreateConnectorActionTests extends OpenSearchTestCase {

private static final String CONNECTOR_ID = "connector_id";
private static final String TENANT_ID = "tenant_id";

private static TestThreadPool testThreadPool = new TestThreadPool(
TransportCreateConnectorActionTests.class.getName(),
Expand Down Expand Up @@ -133,14 +138,17 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase {
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Captor
private ArgumentCaptor<PutDataObjectRequest> putDataObjectRequestArgumentCaptor;

private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList
.of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$");

@Before
public void setup() {
MockitoAnnotations.openMocks(this);

sdkClient = new LocalClusterIndicesClient(client, xContentRegistry);
sdkClient = Mockito.spy(new LocalClusterIndicesClient(client, xContentRegistry));
indexResponse = new IndexResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true);

settings = Settings
Expand Down Expand Up @@ -304,10 +312,13 @@ public void test_execute_connectorAccessControlEnabled_success() throws Interrup

CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<MLCreateConnectorResponse> latchedActionListener = new LatchedActionListener<>(actionListener, latch);
request.getMlCreateConnectorInput().setTenantId(TENANT_ID);
action.doExecute(task, request, latchedActionListener);
latch.await(500, TimeUnit.MILLISECONDS);

verify(actionListener).onResponse(any(MLCreateConnectorResponse.class));
verify(sdkClient).putDataObjectAsync(putDataObjectRequestArgumentCaptor.capture(), Mockito.any());
Assert.assertEquals(TENANT_ID, putDataObjectRequestArgumentCaptor.getValue().tenantId());
}

public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_defaultToPrivate() throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@

import org.apache.lucene.search.TotalHits;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetRequest;
Expand Down Expand Up @@ -63,6 +66,7 @@
import org.opensearch.ml.sdkclient.LocalClusterIndicesClient;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.sdk.PutDataObjectRequest;
import org.opensearch.sdk.SdkClient;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
Expand All @@ -73,6 +77,8 @@
import org.opensearch.threadpool.ThreadPool;

public class MLModelGroupManagerTests extends OpenSearchTestCase {
private static final String TENANT_ID = "tenant_id";

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down Expand Up @@ -112,6 +118,9 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase {
@Mock
NamedXContentRegistry xContentRegistry;

@Captor
ArgumentCaptor<PutDataObjectRequest> putDataObjectRequestArgumentCaptor;

private final List<String> backendRoles = Arrays.asList("IT", "HR");

private static TestThreadPool testThreadPool = new TestThreadPool(
Expand All @@ -129,7 +138,7 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase {
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);

sdkClient = new LocalClusterIndicesClient(client, xContentRegistry);
sdkClient = Mockito.spy(new LocalClusterIndicesClient(client, xContentRegistry));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
mlModelGroupManager = new MLModelGroupManager(
Expand Down Expand Up @@ -218,6 +227,8 @@ public void test_SuccessPublic() throws InterruptedException {
latch.await(500, TimeUnit.MILLISECONDS);
ArgumentCaptor<String> argumentCaptor = ArgumentCaptor.forClass(String.class);
verify(actionListener).onResponse(argumentCaptor.capture());
verify(sdkClient).putDataObjectAsync(putDataObjectRequestArgumentCaptor.capture(), Mockito.any());
Assert.assertEquals(TENANT_ID, putDataObjectRequestArgumentCaptor.getValue().tenantId());
}

@Test
Expand Down Expand Up @@ -535,6 +546,7 @@ private MLRegisterModelGroupInput prepareRequest(List<String> backendRoles, Acce
.backendRoles(backendRoles)
.modelAccessMode(modelAccessMode)
.isAddAllBackendRoles(isAddAllBackendRoles)
.tenantId(TENANT_ID)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.opensearch.sdk.SearchDataObjectResponse;
import org.opensearch.sdk.UpdateDataObjectRequest;
import org.opensearch.sdk.UpdateDataObjectResponse;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ScalingExecutorBuilder;
import org.opensearch.threadpool.TestThreadPool;
Expand All @@ -78,6 +79,8 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase {
private static final String TEST_ID = "123";
private static final String TENANT_ID = "TEST_TENANT_ID";
private static final String TEST_INDEX = "test_index";
private static final String TEST_INDEX_2 = "test_index_2";
private static final String TEST_SYSTEM_INDEX = ".test_index";
private SdkClient sdkClient;

@Mock
Expand All @@ -92,6 +95,8 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase {
private ArgumentCaptor<DeleteItemRequest> deleteItemRequestArgumentCaptor;
@Captor
private ArgumentCaptor<UpdateItemRequest> updateItemRequestArgumentCaptor;
@Captor
private ArgumentCaptor<SearchDataObjectRequest> searchDataObjectRequestArgumentCaptor;
private TestDataObject testDataObject;

private static TestThreadPool testThreadPool = new TestThreadPool(
Expand Down Expand Up @@ -408,17 +413,39 @@ public void updateDataObjectAsync_NullTenantId_UsesDefaultTenantId() {

@Test
public void searchDataObjectAsync_HappyCase() {
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource();
SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder()
.indices(TEST_INDEX, TEST_INDEX_2)
.tenantId(TENANT_ID)
.searchSourceBuilder(searchSourceBuilder)
.build();
CompletionStage<SearchDataObjectResponse> searchDataObjectResponse = Mockito.mock(CompletionStage.class);
Mockito.when(remoteClusterIndicesClient.searchDataObjectAsync(Mockito.any(), Mockito.any())).thenReturn(searchDataObjectResponse);
CompletionStage<SearchDataObjectResponse> searchResponse = sdkClient.searchDataObjectAsync(searchDataObjectRequest);

assertEquals(searchDataObjectResponse, searchResponse);
Mockito.verify(remoteClusterIndicesClient).searchDataObjectAsync(searchDataObjectRequestArgumentCaptor.capture(), Mockito.any());
Assert.assertEquals(TENANT_ID, searchDataObjectRequestArgumentCaptor.getValue().tenantId());
Assert.assertEquals(TEST_INDEX, searchDataObjectRequestArgumentCaptor.getValue().indices()[0]);
Assert.assertEquals(TEST_INDEX_2, searchDataObjectRequestArgumentCaptor.getValue().indices()[1]);
Assert.assertEquals(searchSourceBuilder, searchDataObjectRequestArgumentCaptor.getValue().searchSourceBuilder());
}

@Test
public void searchDataObjectAsync_SystemIndex() {
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource();
SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder()
.indices(TEST_INDEX)
.indices(TEST_SYSTEM_INDEX)
.tenantId(TENANT_ID)
.searchSourceBuilder(searchSourceBuilder)
.build();
CompletionStage<SearchDataObjectResponse> searchDataObjectResponse = Mockito.mock(CompletionStage.class);
Mockito
.when(remoteClusterIndicesClient.searchDataObjectAsync(Mockito.eq(searchDataObjectRequest), Mockito.any()))
.thenReturn(searchDataObjectResponse);
Mockito.when(remoteClusterIndicesClient.searchDataObjectAsync(Mockito.any(), Mockito.any())).thenReturn(searchDataObjectResponse);
CompletionStage<SearchDataObjectResponse> searchResponse = sdkClient.searchDataObjectAsync(searchDataObjectRequest);

assertEquals(searchDataObjectResponse, searchResponse);
Mockito.verify(remoteClusterIndicesClient).searchDataObjectAsync(searchDataObjectRequestArgumentCaptor.capture(), Mockito.any());
Assert.assertEquals("test_index", searchDataObjectRequestArgumentCaptor.getValue().indices()[0]);
}

}

0 comments on commit 0c5c7fd

Please sign in to comment.