From 421d3c5360fa15db1408c00a2201204a28c2c66a Mon Sep 17 00:00:00 2001 From: Kunal Kotwani Date: Tue, 19 Sep 2023 16:59:28 -0700 Subject: [PATCH] Add async blob read support for encrypted containers Signed-off-by: Kunal Kotwani --- CHANGELOG.md | 3 +- ...syncMultiStreamEncryptedBlobContainer.java | 95 +++++--------- .../blobstore/EncryptedBlobContainer.java | 2 +- ...ultiStreamEncryptedBlobContainerTests.java | 122 ++++++++++++++++++ .../read/listener/ListenerTestUtils.java | 2 +- 5 files changed, 158 insertions(+), 66 deletions(-) create mode 100644 server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cb5d39c1a764..763f0bd7925da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - Add metrics for thread_pool task wait time ([#9681](https://github.com/opensearch-project/OpenSearch/pull/9681)) - Async blob read support for S3 plugin ([#9694](https://github.com/opensearch-project/OpenSearch/pull/9694)) +- Async blob read support for encrypted containers ([#10131](https://github.com/opensearch-project/OpenSearch/pull/10131)) ### Dependencies - Bump `peter-evans/create-or-update-comment` from 2 to 3 ([#9575](https://github.com/opensearch-project/OpenSearch/pull/9575)) @@ -108,4 +109,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Security [Unreleased 3.0]: https://github.com/opensearch-project/OpenSearch/compare/2.x...HEAD -[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.11...2.x \ No newline at end of file +[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.11...2.x diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index 9021ced7d9af6..46e79781d8b73 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.InputStream; +import java.io.UncheckedIOException; import java.util.List; import java.util.stream.Collectors; @@ -47,24 +48,16 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp @Override public void readBlobAsync(String blobName, ActionListener listener) { - DecryptingReadContextListener decryptingReadContextListener = new DecryptingReadContextListener<>( - listener, - cryptoHandler, - getEncryptedHeaderContentSupplier(blobName) - ); - blobContainer.readBlobAsync(blobName, decryptingReadContextListener); - } + try { + ActionListener decryptingCompletionListener = ActionListener.map( + listener, + readContext -> new DecryptedReadContext<>(readContext, cryptoHandler, getEncryptedHeaderContentSupplier(blobName)) + ); - private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { - return (start, end) -> { - byte[] buffer; - int length = (int) (end - start + 1); - try (InputStream inputStream = blobContainer.readBlob(blobName, start, length)) { - buffer = new byte[length]; - inputStream.readNBytes(buffer, (int) start, buffer.length); - } - return buffer; - }; + blobContainer.readBlobAsync(blobName, decryptingCompletionListener); + } catch (Exception e) { + listener.onFailure(e); + } } @Override @@ -124,47 +117,17 @@ public InputStreamContainer provideStream(int partNumber) throws IOException { } - static class DecryptingReadContextListener implements ActionListener { - - private final ActionListener completionListener; - private final CryptoHandler cryptoHandler; - private final EncryptedHeaderContentSupplier encryptedHeaderContentSupplier; - - public DecryptingReadContextListener( - ActionListener completionListener, - CryptoHandler cryptoHandler, - EncryptedHeaderContentSupplier headerContentSupplier - ) { - this.completionListener = completionListener; - this.cryptoHandler = cryptoHandler; - this.encryptedHeaderContentSupplier = headerContentSupplier; - } - - @Override - public void onResponse(ReadContext readContext) { - try { - DecryptedReadContext decryptedReadContext = new DecryptedReadContext<>( - readContext, - cryptoHandler, - encryptedHeaderContentSupplier - ); - completionListener.onResponse(decryptedReadContext); - } catch (Exception e) { - onFailure(e); - } - } - - @Override - public void onFailure(Exception e) { - completionListener.onFailure(e); - } - } - + /** + * DecryptedReadContext decrypts the encrypted {@link ReadContext} by acting as a transformation wrapper around + * the encrypted object + * @param Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance + * @param Parsed Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance + */ static class DecryptedReadContext extends ReadContext { - private final U cryptoContext; private final CryptoHandler cryptoHandler; - private final long fileSize; + private final U cryptoContext; + private Long blobSize; public DecryptedReadContext( ReadContext readContext, @@ -175,15 +138,18 @@ public DecryptedReadContext( try { this.cryptoHandler = cryptoHandler; this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(headerContentSupplier); - this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize()); } catch (IOException e) { - throw new RuntimeException(e); + throw new UncheckedIOException(e); } } @Override public long getBlobSize() { - return fileSize; + // initializes the value lazily + if (blobSize == null) { + this.blobSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, super.getBlobSize()); + } + return this.blobSize; } @Override @@ -191,6 +157,11 @@ public List getPartStreams() { return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList()); } + /** + * Transforms an encrypted {@link InputStreamContainer} to a decrypted instance + * @param inputStreamContainer encrypted input stream container instance + * @return decrypted input stream container instance + */ private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) { long startOfStream = inputStreamContainer.getOffset(); long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1; @@ -202,11 +173,9 @@ private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer in long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0]; long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1; - return new InputStreamContainer( - decryptedStreamProvider.getDecryptedStreamProvider().apply(inputStreamContainer.getInputStream()), - adjustedPos, - adjustedLength - ); + final InputStream decryptedStream = decryptedStreamProvider.getDecryptedStreamProvider() + .apply(inputStreamContainer.getInputStream()); + return new InputStreamContainer(decryptedStream, adjustedLength, adjustedPos); } } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java index 475d891ea9336..d0933741339d9 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java @@ -50,7 +50,7 @@ public InputStream readBlob(String blobName) throws IOException { return cryptoHandler.createDecryptingStream(inputStream); } - private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { + EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { return (start, end) -> { byte[] buffer; int length = (int) (end - start + 1); diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java new file mode 100644 index 0000000000000..b9cdcd84f6ca6 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore; + +import org.opensearch.common.Randomness; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils; +import org.opensearch.common.crypto.CryptoHandler; +import org.opensearch.common.crypto.DecryptedRangedStreamProvider; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.function.UnaryOperator; + +import org.mockito.Mockito; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AsyncMultiStreamEncryptedBlobContainerTests extends OpenSearchTestCase { + + // Tests the happy path scenario for decrypting a read context + @SuppressWarnings("unchecked") + public void testReadBlobAsync() throws Exception { + String testBlobName = "testBlobName"; + int size = 100; + + // Mock objects needed for the test + AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class); + CryptoHandler cryptoHandler = mock(CryptoHandler.class); + Object cryptoContext = mock(Object.class); + when(cryptoHandler.loadEncryptionMetadata(any())).thenReturn(cryptoContext); + when(cryptoHandler.estimateDecryptedLength(any(), anyLong())).thenReturn((long) size); + long[] adjustedRanges = { 0, size - 1 }; + DecryptedRangedStreamProvider rangedStreamProvider = new DecryptedRangedStreamProvider(adjustedRanges, UnaryOperator.identity()); + when(cryptoHandler.createDecryptingStreamOfRange(eq(cryptoContext), anyLong(), anyLong())).thenReturn(rangedStreamProvider); + + // Objects needed for API call + final byte[] data = new byte[size]; + Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); + final ListenerTestUtils.CountingCompletionListener completionListener = + new ListenerTestUtils.CountingCompletionListener<>(); + final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + + Mockito.doAnswer(invocation -> { + ActionListener readContextActionListener = invocation.getArgument(1); + readContextActionListener.onResponse(readContext); + return null; + }).when(blobContainer).readBlobAsync(eq(testBlobName), any()); + + AsyncMultiStreamEncryptedBlobContainer asyncMultiStreamEncryptedBlobContainer = + new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler); + asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener); + + // Assert results + ReadContext response = completionListener.getResponse(); + assertEquals(0, completionListener.getFailureCount()); + assertEquals(1, completionListener.getResponseCount()); + assertNull(completionListener.getException()); + + assertTrue(response instanceof AsyncMultiStreamEncryptedBlobContainer.DecryptedReadContext); + assertEquals(1, response.getNumberOfParts()); + assertEquals(size, response.getBlobSize()); + + InputStreamContainer responseContainer = response.getPartStreams().get(0); + assertEquals(0, responseContainer.getOffset()); + assertEquals(size, responseContainer.getContentLength()); + assertEquals(100, responseContainer.getInputStream().available()); + } + + // Tests the exception scenario for decrypting a read context + @SuppressWarnings("unchecked") + public void testReadBlobAsyncException() throws Exception { + String testBlobName = "testBlobName"; + int size = 100; + + // Mock objects needed for the test + AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class); + CryptoHandler cryptoHandler = mock(CryptoHandler.class); + when(cryptoHandler.loadEncryptionMetadata(any())).thenThrow(new IOException()); + + // Objects needed for API call + final byte[] data = new byte[size]; + Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); + final ListenerTestUtils.CountingCompletionListener completionListener = + new ListenerTestUtils.CountingCompletionListener<>(); + final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + + Mockito.doAnswer(invocation -> { + ActionListener readContextActionListener = invocation.getArgument(1); + readContextActionListener.onResponse(readContext); + return null; + }).when(blobContainer).readBlobAsync(eq(testBlobName), any()); + + AsyncMultiStreamEncryptedBlobContainer asyncMultiStreamEncryptedBlobContainer = + new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler); + asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener); + + // Assert results + assertEquals(1, completionListener.getFailureCount()); + assertEquals(0, completionListener.getResponseCount()); + assertNull(completionListener.getResponse()); + assertTrue(completionListener.getException() instanceof UncheckedIOException); + } + +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java index 1e9450c83e3ab..a3a32f6db2148 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java @@ -19,7 +19,7 @@ public class ListenerTestUtils { * CountingCompletionListener acts as a verification instance for wrapping listener based calls. * Keeps track of the last response, failure and count of response and failure invocations. */ - static class CountingCompletionListener implements ActionListener { + public static class CountingCompletionListener implements ActionListener { private int responseCount; private int failureCount; private T response;