From 2481c868e06a628b7a4577ef23e0e137a1f6d187 Mon Sep 17 00:00:00 2001 From: Kunal Kotwani Date: Fri, 29 Sep 2023 13:40:02 -0700 Subject: [PATCH] Refactor async blob read to avoid blocking calls, support non multipart calls (#10192) Signed-off-by: Kunal Kotwani Signed-off-by: Shivansh Arora --- .../repositories/s3/S3BlobContainer.java | 89 +++++++----- .../s3/S3BlobStoreContainerTests.java | 129 +++++++++++++++--- 2 files changed, 163 insertions(+), 55 deletions(-) diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index 2911a018df337..c6ae58371e15c 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -228,35 +228,50 @@ public void readBlobAsync(String blobName, ActionListener listener) try (AmazonAsyncS3Reference amazonS3Reference = SocketAccess.doPrivileged(blobStore::asyncClientReference)) { final S3AsyncClient s3AsyncClient = amazonS3Reference.get().client(); final String bucketName = blobStore.bucket(); + final String blobKey = buildKey(blobName); - final GetObjectAttributesResponse blobMetadata = getBlobMetadata(s3AsyncClient, bucketName, blobName).get(); + final CompletableFuture blobMetadataFuture = getBlobMetadata(s3AsyncClient, bucketName, blobKey); - final long blobSize = blobMetadata.objectSize(); - final int numberOfParts = blobMetadata.objectParts().totalPartsCount(); - final String blobChecksum = blobMetadata.checksum().checksumCRC32(); - - final List blobPartStreams = new ArrayList<>(); - final List> blobPartInputStreamFutures = new ArrayList<>(); - // S3 multipart files use 1 to n indexing - for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) { - blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, partNumber)); - } - - CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)).whenComplete((unused, throwable) -> { - if (throwable == null) { - listener.onResponse( - new ReadContext( - blobSize, - blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()), - blobChecksum - ) - ); - } else { + blobMetadataFuture.whenComplete((blobMetadata, throwable) -> { + if (throwable != null) { Exception ex = throwable.getCause() instanceof Exception ? (Exception) throwable.getCause() : new Exception(throwable.getCause()); listener.onFailure(ex); + return; + } + + final List> blobPartInputStreamFutures = new ArrayList<>(); + final long blobSize = blobMetadata.objectSize(); + final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount(); + final String blobChecksum = blobMetadata.checksum().checksumCRC32(); + + if (numberOfParts == null) { + blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null)); + } else { + // S3 multipart files use 1 to n indexing + for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) { + blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber)); + } } + + CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)) + .whenComplete((unused, partThrowable) -> { + if (partThrowable == null) { + listener.onResponse( + new ReadContext( + blobSize, + blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()), + blobChecksum + ) + ); + } else { + Exception ex = partThrowable.getCause() instanceof Exception + ? (Exception) partThrowable.getCause() + : new Exception(partThrowable.getCause()); + listener.onFailure(ex); + } + }); }); } catch (Exception ex) { listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex)); @@ -685,41 +700,47 @@ static Tuple numberOfMultiparts(final long totalSize, final long par * the stream and its related metadata. * @param s3AsyncClient Async client to be utilized to fetch the object part * @param bucketName Name of the S3 bucket - * @param blobName Identifier of the blob for which the parts will be fetched - * @param partNumber Part number for the blob to be retrieved + * @param blobKey Identifier of the blob for which the parts will be fetched + * @param partNumber Optional part number for the blob to be retrieved * @return A future of {@link InputStreamContainer} containing the stream and stream metadata. */ CompletableFuture getBlobPartInputStreamContainer( S3AsyncClient s3AsyncClient, String bucketName, - String blobName, - int partNumber + String blobKey, + @Nullable Integer partNumber ) { - final GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder() - .bucket(bucketName) - .key(blobName) - .partNumber(partNumber); + final boolean isMultipartObject = partNumber != null; + final GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder().bucket(bucketName).key(blobKey); + + if (isMultipartObject) { + getObjectRequestBuilder.partNumber(partNumber); + } return SocketAccess.doPrivileged( () -> s3AsyncClient.getObject(getObjectRequestBuilder.build(), AsyncResponseTransformer.toBlockingInputStream()) - .thenApply(S3BlobContainer::transformResponseToInputStreamContainer) + .thenApply(response -> transformResponseToInputStreamContainer(response, isMultipartObject)) ); } /** * Transforms the stream response object from S3 into an {@link InputStreamContainer} * @param streamResponse Response stream object from S3 + * @param isMultipartObject Flag to denote a multipart object response * @return {@link InputStreamContainer} containing the stream and stream metadata */ // Package-Private for testing. - static InputStreamContainer transformResponseToInputStreamContainer(ResponseInputStream streamResponse) { + static InputStreamContainer transformResponseToInputStreamContainer( + ResponseInputStream streamResponse, + boolean isMultipartObject + ) { final GetObjectResponse getObjectResponse = streamResponse.response(); final String contentRange = getObjectResponse.contentRange(); final Long contentLength = getObjectResponse.contentLength(); - if (contentRange == null || contentLength == null) { + if ((isMultipartObject && contentRange == null) || contentLength == null) { throw SdkException.builder().message("Failed to fetch required metadata for blob part").build(); } - final Long offset = HttpRangeUtils.getStartOffsetFromRangeHeader(getObjectResponse.contentRange()); + final long offset = isMultipartObject ? HttpRangeUtils.getStartOffsetFromRangeHeader(getObjectResponse.contentRange()) : 0L; return new InputStreamContainer(streamResponse, getObjectResponse.contentLength(), offset); } diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java index a87c060dcc60a..9817d7cd520ef 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java @@ -81,7 +81,6 @@ import org.opensearch.common.io.InputStreamContainer; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.unit.ByteSizeUnit; -import org.opensearch.repositories.s3.async.AsyncTransferManager; import org.opensearch.test.OpenSearchTestCase; import java.io.ByteArrayInputStream; @@ -100,7 +99,6 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -919,7 +917,7 @@ public void testListBlobsByPrefixInLexicographicOrderWithLimitGreaterThanNumberO testListBlobsByPrefixInLexicographicOrder(12, 2, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC); } - public void testReadBlobAsync() throws Exception { + public void testReadBlobAsyncMultiPart() throws Exception { final String bucketName = randomAlphaOfLengthBetween(1, 10); final String blobName = randomAlphaOfLengthBetween(1, 10); final String checksum = randomAlphaOfLength(10); @@ -932,11 +930,7 @@ public void testReadBlobAsync() throws Exception { final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference( AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null) ); - final AsyncTransferManager asyncTransferManager = new AsyncTransferManager( - 10000L, - mock(ExecutorService.class), - mock(ExecutorService.class) - ); + final S3BlobStore blobStore = mock(S3BlobStore.class); final BlobPath blobPath = new BlobPath(); @@ -944,7 +938,6 @@ public void testReadBlobAsync() throws Exception { when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher()); when(blobStore.serverSideEncryption()).thenReturn(false); when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference); - when(blobStore.getAsyncTransferManager()).thenReturn(asyncTransferManager); CompletableFuture getObjectAttributesResponseCompletableFuture = new CompletableFuture<>(); getObjectAttributesResponseCompletableFuture.complete( @@ -984,6 +977,60 @@ public void testReadBlobAsync() throws Exception { } } + public void testReadBlobAsyncSinglePart() throws Exception { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String checksum = randomAlphaOfLength(10); + + final int objectSize = 100; + + final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class); + final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference( + AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null) + ); + final S3BlobStore blobStore = mock(S3BlobStore.class); + final BlobPath blobPath = new BlobPath(); + + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher()); + when(blobStore.serverSideEncryption()).thenReturn(false); + when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference); + + CompletableFuture getObjectAttributesResponseCompletableFuture = new CompletableFuture<>(); + getObjectAttributesResponseCompletableFuture.complete( + GetObjectAttributesResponse.builder() + .checksum(Checksum.builder().checksumCRC32(checksum).build()) + .objectSize((long) objectSize) + .build() + ); + when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenReturn( + getObjectAttributesResponseCompletableFuture + ); + + mockObjectResponse(s3AsyncClient, bucketName, blobName, objectSize); + + CountDownLatch countDownLatch = new CountDownLatch(1); + CountingCompletionListener readContextActionListener = new CountingCompletionListener<>(); + LatchedActionListener listener = new LatchedActionListener<>(readContextActionListener, countDownLatch); + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + blobContainer.readBlobAsync(blobName, listener); + countDownLatch.await(); + + assertEquals(1, readContextActionListener.getResponseCount()); + assertEquals(0, readContextActionListener.getFailureCount()); + ReadContext readContext = readContextActionListener.getResponse(); + assertEquals(1, readContext.getNumberOfParts()); + assertEquals(checksum, readContext.getBlobChecksum()); + assertEquals(objectSize, readContext.getBlobSize()); + + InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get(); + assertEquals(objectSize, inputStreamContainer.getContentLength()); + assertEquals(0, inputStreamContainer.getOffset()); + assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length); + + } + public void testReadBlobAsyncFailure() throws Exception { final String bucketName = randomAlphaOfLengthBetween(1, 10); final String blobName = randomAlphaOfLengthBetween(1, 10); @@ -996,11 +1043,7 @@ public void testReadBlobAsyncFailure() throws Exception { final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference( AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null) ); - final AsyncTransferManager asyncTransferManager = new AsyncTransferManager( - 10000L, - mock(ExecutorService.class), - mock(ExecutorService.class) - ); + final S3BlobStore blobStore = mock(S3BlobStore.class); final BlobPath blobPath = new BlobPath(); @@ -1008,7 +1051,6 @@ public void testReadBlobAsyncFailure() throws Exception { when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher()); when(blobStore.serverSideEncryption()).thenReturn(false); when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference); - when(blobStore.getAsyncTransferManager()).thenReturn(asyncTransferManager); CompletableFuture getObjectAttributesResponseCompletableFuture = new CompletableFuture<>(); getObjectAttributesResponseCompletableFuture.complete( @@ -1071,7 +1113,7 @@ public void testGetBlobPartInputStream() throws Exception { final String blobName = randomAlphaOfLengthBetween(1, 10); final String bucketName = randomAlphaOfLengthBetween(1, 10); final long contentLength = 10L; - final String contentRange = "bytes 0-10/100"; + final String contentRange = "bytes 10-20/100"; final InputStream inputStream = ResponseInputStream.nullInputStream(); final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class); @@ -1095,9 +1137,17 @@ public void testGetBlobPartInputStream() throws Exception { ) ).thenReturn(getObjectPartResponse); + // Header based offset in case of a multi part object request InputStreamContainer inputStreamContainer = blobContainer.getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, 0) .get(); + assertEquals(10, inputStreamContainer.getOffset()); + assertEquals(contentLength, inputStreamContainer.getContentLength()); + assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available()); + + // 0 offset in case of a single part object request + inputStreamContainer = blobContainer.getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, null).get(); + assertEquals(0, inputStreamContainer.getOffset()); assertEquals(contentLength, inputStreamContainer.getContentLength()); assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available()); @@ -1108,28 +1158,65 @@ public void testTransformResponseToInputStreamContainer() throws Exception { final long contentLength = 10L; final InputStream inputStream = ResponseInputStream.nullInputStream(); - final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class); - GetObjectResponse getObjectResponse = GetObjectResponse.builder().contentLength(contentLength).build(); + // Exception when content range absent for multipart object ResponseInputStream responseInputStreamNoRange = new ResponseInputStream<>(getObjectResponse, inputStream); - assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoRange)); + assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoRange, true)); + + // No exception when content range absent for single part object + ResponseInputStream responseInputStreamNoRangeSinglePart = new ResponseInputStream<>( + getObjectResponse, + inputStream + ); + InputStreamContainer inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer( + responseInputStreamNoRangeSinglePart, + false + ); + assertEquals(contentLength, inputStreamContainer.getContentLength()); + assertEquals(0, inputStreamContainer.getOffset()); + // Exception when length is absent getObjectResponse = GetObjectResponse.builder().contentRange(contentRange).build(); ResponseInputStream responseInputStreamNoContentLength = new ResponseInputStream<>( getObjectResponse, inputStream ); - assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoContentLength)); + assertThrows( + SdkException.class, + () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoContentLength, true) + ); + // No exception when range and length both are present getObjectResponse = GetObjectResponse.builder().contentRange(contentRange).contentLength(contentLength).build(); ResponseInputStream responseInputStream = new ResponseInputStream<>(getObjectResponse, inputStream); - InputStreamContainer inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(responseInputStream); + inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(responseInputStream, true); assertEquals(contentLength, inputStreamContainer.getContentLength()); assertEquals(0, inputStreamContainer.getOffset()); assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available()); } + private void mockObjectResponse(S3AsyncClient s3AsyncClient, String bucketName, String blobName, int objectSize) { + + final InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(objectSize)); + + GetObjectResponse getObjectResponse = GetObjectResponse.builder().contentLength((long) objectSize).build(); + + CompletableFuture> getObjectPartResponse = new CompletableFuture<>(); + ResponseInputStream responseInputStream = new ResponseInputStream<>(getObjectResponse, inputStream); + getObjectPartResponse.complete(responseInputStream); + + GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(blobName).build(); + + when( + s3AsyncClient.getObject( + eq(getObjectRequest), + ArgumentMatchers.>>any() + ) + ).thenReturn(getObjectPartResponse); + + } + private void mockObjectPartResponse( S3AsyncClient s3AsyncClient, String bucketName,