diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index f9e1dda89e..cecb3d7816 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -75,6 +75,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index ce25a1b944..5faf833e89 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -124,7 +124,11 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener { sdkClient .putDataObjectAsync( - new PutDataObjectRequest.Builder().index(ML_MODEL_GROUP_INDEX).dataObject(mlModelGroup).build(), + new PutDataObjectRequest.Builder() + .tenantId(mlModelGroup.getTenantId()) + .index(ML_MODEL_GROUP_INDEX) + .dataObject(mlModelGroup) + .build(), client.threadPool().executor(GENERAL_THREAD_POOL) ) .whenComplete((r, throwable) -> { diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index c0f5197d6d..a5ddd05855 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -15,11 +15,14 @@ import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; +import java.util.stream.Collectors; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; @@ -249,7 +252,14 @@ public CompletionStage deleteDataObjectAsync(DeleteDat */ @Override public CompletionStage searchDataObjectAsync(SearchDataObjectRequest request, Executor executor) { - return this.remoteClusterIndicesClient.searchDataObjectAsync(request, executor); + List indices = Arrays.stream(request.indices()).map(this::getTableName).collect(Collectors.toList()); + + SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest( + indices.toArray(new String[0]), + request.tenantId(), + request.searchSourceBuilder() + ); + return this.remoteClusterIndicesClient.searchDataObjectAsync(searchDataObjectRequest, executor); } private String getTableName(String index) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index 36193fdff6..2016ac2409 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -19,9 +19,11 @@ import java.util.concurrent.TimeUnit; import org.junit.AfterClass; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; @@ -50,6 +52,7 @@ import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.sdk.GetDataObjectRequest; import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; @@ -60,6 +63,8 @@ public class GetConnectorTransportActionTests extends OpenSearchTestCase { private static final String CONNECTOR_ID = "connector_id"; + private static final String TENANT_ID = "tenant_id"; + private static final TestThreadPool testThreadPool = new TestThreadPool( GetConnectorTransportActionTests.class.getName(), new ScalingExecutorBuilder( @@ -104,12 +109,15 @@ public class GetConnectorTransportActionTests extends OpenSearchTestCase { @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Captor + private ArgumentCaptor getDataObjectRequestArgumentCaptor; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); - mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).build(); + mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).tenantId(TENANT_ID).build(); when(getResponse.getId()).thenReturn(CONNECTOR_ID); when(getResponse.getSourceAsString()).thenReturn("{}"); when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); @@ -198,13 +206,16 @@ public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, I ActionListener listener = invocation.getArgument(5); listener.onResponse(httpConnector); return null; - }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), + getDataObjectRequestArgumentCaptor.capture(), any(), any()); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); latch.await(500, TimeUnit.MILLISECONDS); + Assert.assertEquals(tenantId, getDataObjectRequestArgumentCaptor.getValue().tenantId()); + Assert.assertEquals(CONNECTOR_ID, getDataObjectRequestArgumentCaptor.getValue().id()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLConnectorGetResponse.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertEquals(tenantId, argumentCaptor.getValue().getMlConnector().getTenantId()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index b857f0199f..254162e337 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -26,9 +26,12 @@ import java.util.concurrent.TimeUnit; import org.junit.AfterClass; +import org.junit.Assert; import org.junit.Before; import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.LatchedActionListener; @@ -60,6 +63,7 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; @@ -74,6 +78,7 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { private static final String CONNECTOR_ID = "connector_id"; + private static final String TENANT_ID = "tenant_id"; private static TestThreadPool testThreadPool = new TestThreadPool( TransportCreateConnectorActionTests.class.getName(), @@ -133,6 +138,9 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Captor + private ArgumentCaptor putDataObjectRequestArgumentCaptor; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @@ -140,7 +148,7 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { public void setup() { MockitoAnnotations.openMocks(this); - sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); + sdkClient = Mockito.spy(new LocalClusterIndicesClient(client, xContentRegistry)); indexResponse = new IndexResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true); settings = Settings @@ -304,10 +312,13 @@ public void test_execute_connectorAccessControlEnabled_success() throws Interrup CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + request.getMlCreateConnectorInput().setTenantId(TENANT_ID); action.doExecute(task, request, latchedActionListener); latch.await(500, TimeUnit.MILLISECONDS); verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); + verify(sdkClient).putDataObjectAsync(putDataObjectRequestArgumentCaptor.capture(), Mockito.any()); + Assert.assertEquals(TENANT_ID, putDataObjectRequestArgumentCaptor.getValue().tenantId()); } public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_defaultToPrivate() throws InterruptedException { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 0ca07a27de..2059d7cbca 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -22,12 +22,15 @@ import org.apache.lucene.search.TotalHits; import org.junit.AfterClass; +import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.get.GetRequest; @@ -63,6 +66,7 @@ import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -73,6 +77,8 @@ import org.opensearch.threadpool.ThreadPool; public class MLModelGroupManagerTests extends OpenSearchTestCase { + private static final String TENANT_ID = "tenant_id"; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -112,6 +118,9 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { @Mock NamedXContentRegistry xContentRegistry; + @Captor + ArgumentCaptor putDataObjectRequestArgumentCaptor; + private final List backendRoles = Arrays.asList("IT", "HR"); private static TestThreadPool testThreadPool = new TestThreadPool( @@ -129,7 +138,7 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { public void setup() throws IOException { MockitoAnnotations.openMocks(this); - sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); + sdkClient = Mockito.spy(new LocalClusterIndicesClient(client, xContentRegistry)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); mlModelGroupManager = new MLModelGroupManager( @@ -218,6 +227,8 @@ public void test_SuccessPublic() throws InterruptedException { latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); verify(actionListener).onResponse(argumentCaptor.capture()); + verify(sdkClient).putDataObjectAsync(putDataObjectRequestArgumentCaptor.capture(), Mockito.any()); + Assert.assertEquals(TENANT_ID, putDataObjectRequestArgumentCaptor.getValue().tenantId()); } @Test @@ -535,6 +546,7 @@ private MLRegisterModelGroupInput prepareRequest(List backendRoles, Acce .backendRoles(backendRoles) .modelAccessMode(modelAccessMode) .isAddAllBackendRoles(isAddAllBackendRoles) + .tenantId(TENANT_ID) .build(); } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index ebb945e60c..338076da02 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -55,6 +55,7 @@ import org.opensearch.sdk.SearchDataObjectResponse; import org.opensearch.sdk.UpdateDataObjectRequest; import org.opensearch.sdk.UpdateDataObjectResponse; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -78,6 +79,8 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase { private static final String TEST_ID = "123"; private static final String TENANT_ID = "TEST_TENANT_ID"; private static final String TEST_INDEX = "test_index"; + private static final String TEST_INDEX_2 = "test_index_2"; + private static final String TEST_SYSTEM_INDEX = ".test_index"; private SdkClient sdkClient; @Mock @@ -92,6 +95,8 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase { private ArgumentCaptor deleteItemRequestArgumentCaptor; @Captor private ArgumentCaptor updateItemRequestArgumentCaptor; + @Captor + private ArgumentCaptor searchDataObjectRequestArgumentCaptor; private TestDataObject testDataObject; private static TestThreadPool testThreadPool = new TestThreadPool( @@ -408,17 +413,39 @@ public void updateDataObjectAsync_NullTenantId_UsesDefaultTenantId() { @Test public void searchDataObjectAsync_HappyCase() { + SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource(); + SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder() + .indices(TEST_INDEX, TEST_INDEX_2) + .tenantId(TENANT_ID) + .searchSourceBuilder(searchSourceBuilder) + .build(); + CompletionStage searchDataObjectResponse = Mockito.mock(CompletionStage.class); + Mockito.when(remoteClusterIndicesClient.searchDataObjectAsync(Mockito.any(), Mockito.any())).thenReturn(searchDataObjectResponse); + CompletionStage searchResponse = sdkClient.searchDataObjectAsync(searchDataObjectRequest); + + assertEquals(searchDataObjectResponse, searchResponse); + Mockito.verify(remoteClusterIndicesClient).searchDataObjectAsync(searchDataObjectRequestArgumentCaptor.capture(), Mockito.any()); + Assert.assertEquals(TENANT_ID, searchDataObjectRequestArgumentCaptor.getValue().tenantId()); + Assert.assertEquals(TEST_INDEX, searchDataObjectRequestArgumentCaptor.getValue().indices()[0]); + Assert.assertEquals(TEST_INDEX_2, searchDataObjectRequestArgumentCaptor.getValue().indices()[1]); + Assert.assertEquals(searchSourceBuilder, searchDataObjectRequestArgumentCaptor.getValue().searchSourceBuilder()); + } + + @Test + public void searchDataObjectAsync_SystemIndex() { + SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource(); SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder() - .indices(TEST_INDEX) + .indices(TEST_SYSTEM_INDEX) .tenantId(TENANT_ID) + .searchSourceBuilder(searchSourceBuilder) .build(); CompletionStage searchDataObjectResponse = Mockito.mock(CompletionStage.class); - Mockito - .when(remoteClusterIndicesClient.searchDataObjectAsync(Mockito.eq(searchDataObjectRequest), Mockito.any())) - .thenReturn(searchDataObjectResponse); + Mockito.when(remoteClusterIndicesClient.searchDataObjectAsync(Mockito.any(), Mockito.any())).thenReturn(searchDataObjectResponse); CompletionStage searchResponse = sdkClient.searchDataObjectAsync(searchDataObjectRequest); assertEquals(searchDataObjectResponse, searchResponse); + Mockito.verify(remoteClusterIndicesClient).searchDataObjectAsync(searchDataObjectRequestArgumentCaptor.capture(), Mockito.any()); + Assert.assertEquals("test_index", searchDataObjectRequestArgumentCaptor.getValue().indices()[0]); } }