From bb22c12d3416e9f70ffadec00b10cc86de21231e Mon Sep 17 00:00:00 2001 From: Kunal Kotwani <kkotwani@amazon.com> Date: Mon, 11 Sep 2023 20:11:25 -0700 Subject: [PATCH 1/2] Add support for encrypted async blob read Signed-off-by: Kunal Kotwani <kkotwani@amazon.com> --- ...syncMultiStreamEncryptedBlobContainer.java | 113 +++++++++++++++++- .../blobstore/stream/read/ReadContext.java | 6 + 2 files changed, 113 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index 07a0b49df47ff..9021ced7d9af6 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -12,12 +12,15 @@ import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.crypto.CryptoHandler; +import org.opensearch.common.crypto.DecryptedRangedStreamProvider; +import org.opensearch.common.crypto.EncryptedHeaderContentSupplier; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.nio.file.Path; +import java.io.InputStream; +import java.util.List; +import java.util.stream.Collectors; /** * EncryptedBlobContainer is an encrypted BlobContainer that is backed by a @@ -44,12 +47,24 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp @Override public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) { - throw new UnsupportedOperationException(); + DecryptingReadContextListener<T, U> decryptingReadContextListener = new DecryptingReadContextListener<>( + listener, + cryptoHandler, + getEncryptedHeaderContentSupplier(blobName) + ); + blobContainer.readBlobAsync(blobName, decryptingReadContextListener); } - @Override - public void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) { - throw new UnsupportedOperationException(); + private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { + return (start, end) -> { + byte[] buffer; + int length = (int) (end - start + 1); + try (InputStream inputStream = blobContainer.readBlob(blobName, start, length)) { + buffer = new byte[length]; + inputStream.readNBytes(buffer, (int) start, buffer.length); + } + return buffer; + }; } @Override @@ -108,4 +123,90 @@ public InputStreamContainer provideStream(int partNumber) throws IOException { } } + + static class DecryptingReadContextListener<T, U> implements ActionListener<ReadContext> { + + private final ActionListener<ReadContext> completionListener; + private final CryptoHandler<T, U> cryptoHandler; + private final EncryptedHeaderContentSupplier encryptedHeaderContentSupplier; + + public DecryptingReadContextListener( + ActionListener<ReadContext> completionListener, + CryptoHandler<T, U> cryptoHandler, + EncryptedHeaderContentSupplier headerContentSupplier + ) { + this.completionListener = completionListener; + this.cryptoHandler = cryptoHandler; + this.encryptedHeaderContentSupplier = headerContentSupplier; + } + + @Override + public void onResponse(ReadContext readContext) { + try { + DecryptedReadContext<T, U> decryptedReadContext = new DecryptedReadContext<>( + readContext, + cryptoHandler, + encryptedHeaderContentSupplier + ); + completionListener.onResponse(decryptedReadContext); + } catch (Exception e) { + onFailure(e); + } + } + + @Override + public void onFailure(Exception e) { + completionListener.onFailure(e); + } + } + + static class DecryptedReadContext<T, U> extends ReadContext { + + private final U cryptoContext; + private final CryptoHandler<T, U> cryptoHandler; + private final long fileSize; + + public DecryptedReadContext( + ReadContext readContext, + CryptoHandler<T, U> cryptoHandler, + EncryptedHeaderContentSupplier headerContentSupplier + ) { + super(readContext); + try { + this.cryptoHandler = cryptoHandler; + this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(headerContentSupplier); + this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public long getBlobSize() { + return fileSize; + } + + @Override + public List<InputStreamContainer> getPartStreams() { + return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList()); + } + + private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) { + long startOfStream = inputStreamContainer.getOffset(); + long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1; + DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange( + cryptoContext, + startOfStream, + endOfStream + ); + + long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0]; + long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1; + return new InputStreamContainer( + decryptedStreamProvider.getDecryptedStreamProvider().apply(inputStreamContainer.getInputStream()), + adjustedPos, + adjustedLength + ); + } + } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java index 4ba17959f8040..2c305fb03c475 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java @@ -28,6 +28,12 @@ public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String this.blobChecksum = blobChecksum; } + public ReadContext(ReadContext readContext) { + this.blobSize = readContext.blobSize; + this.partStreams = readContext.partStreams; + this.blobChecksum = readContext.blobChecksum; + } + public String getBlobChecksum() { return blobChecksum; } From 5551aafffb25db7b838db1963a31501070d79864 Mon Sep 17 00:00:00 2001 From: Kunal Kotwani <kkotwani@amazon.com> Date: Tue, 19 Sep 2023 16:59:28 -0700 Subject: [PATCH 2/2] Add async blob read support for encrypted containers Signed-off-by: Kunal Kotwani <kkotwani@amazon.com> --- CHANGELOG.md | 1 + ...syncMultiStreamEncryptedBlobContainer.java | 112 ++++++---------- .../blobstore/EncryptedBlobContainer.java | 2 +- ...ultiStreamEncryptedBlobContainerTests.java | 121 ++++++++++++++++++ .../read/listener/ListenerTestUtils.java | 2 +- 5 files changed, 160 insertions(+), 78 deletions(-) create mode 100644 server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 36d566805ebf4..44db1a5512840 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add metrics for thread_pool task wait time ([#9681](https://github.com/opensearch-project/OpenSearch/pull/9681)) - Async blob read support for S3 plugin ([#9694](https://github.com/opensearch-project/OpenSearch/pull/9694)) - [Telemetry-Otel] Added support for OtlpGrpcSpanExporter exporter ([#9666](https://github.com/opensearch-project/OpenSearch/pull/9666)) +- Async blob read support for encrypted containers ([#10131](https://github.com/opensearch-project/OpenSearch/pull/10131)) ### Dependencies - Bump `peter-evans/create-or-update-comment` from 2 to 3 ([#9575](https://github.com/opensearch-project/OpenSearch/pull/9575)) diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index 9021ced7d9af6..c64dc6b9e3ae4 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -13,7 +13,6 @@ import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.crypto.CryptoHandler; import org.opensearch.common.crypto.DecryptedRangedStreamProvider; -import org.opensearch.common.crypto.EncryptedHeaderContentSupplier; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.core.action.ActionListener; @@ -47,24 +46,17 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp @Override public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) { - DecryptingReadContextListener<T, U> decryptingReadContextListener = new DecryptingReadContextListener<>( - listener, - cryptoHandler, - getEncryptedHeaderContentSupplier(blobName) - ); - blobContainer.readBlobAsync(blobName, decryptingReadContextListener); - } + try { + final U cryptoContext = cryptoHandler.loadEncryptionMetadata(getEncryptedHeaderContentSupplier(blobName)); + ActionListener<ReadContext> decryptingCompletionListener = ActionListener.map( + listener, + readContext -> new DecryptedReadContext<>(readContext, cryptoHandler, cryptoContext) + ); - private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { - return (start, end) -> { - byte[] buffer; - int length = (int) (end - start + 1); - try (InputStream inputStream = blobContainer.readBlob(blobName, start, length)) { - buffer = new byte[length]; - inputStream.readNBytes(buffer, (int) start, buffer.length); - } - return buffer; - }; + blobContainer.readBlobAsync(blobName, decryptingCompletionListener); + } catch (Exception e) { + listener.onFailure(e); + } } @Override @@ -124,74 +116,44 @@ public InputStreamContainer provideStream(int partNumber) throws IOException { } - static class DecryptingReadContextListener<T, U> implements ActionListener<ReadContext> { - - private final ActionListener<ReadContext> completionListener; - private final CryptoHandler<T, U> cryptoHandler; - private final EncryptedHeaderContentSupplier encryptedHeaderContentSupplier; - - public DecryptingReadContextListener( - ActionListener<ReadContext> completionListener, - CryptoHandler<T, U> cryptoHandler, - EncryptedHeaderContentSupplier headerContentSupplier - ) { - this.completionListener = completionListener; - this.cryptoHandler = cryptoHandler; - this.encryptedHeaderContentSupplier = headerContentSupplier; - } - - @Override - public void onResponse(ReadContext readContext) { - try { - DecryptedReadContext<T, U> decryptedReadContext = new DecryptedReadContext<>( - readContext, - cryptoHandler, - encryptedHeaderContentSupplier - ); - completionListener.onResponse(decryptedReadContext); - } catch (Exception e) { - onFailure(e); - } - } - - @Override - public void onFailure(Exception e) { - completionListener.onFailure(e); - } - } - + /** + * DecryptedReadContext decrypts the encrypted {@link ReadContext} by acting as a transformation wrapper around + * the encrypted object + * @param <T> Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance + * @param <U> Parsed Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance + */ static class DecryptedReadContext<T, U> extends ReadContext { - private final U cryptoContext; private final CryptoHandler<T, U> cryptoHandler; - private final long fileSize; + private final U cryptoContext; + private Long blobSize; - public DecryptedReadContext( - ReadContext readContext, - CryptoHandler<T, U> cryptoHandler, - EncryptedHeaderContentSupplier headerContentSupplier - ) { + public DecryptedReadContext(ReadContext readContext, CryptoHandler<T, U> cryptoHandler, U cryptoContext) { super(readContext); - try { - this.cryptoHandler = cryptoHandler; - this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(headerContentSupplier); - this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize()); - } catch (IOException e) { - throw new RuntimeException(e); - } + this.cryptoHandler = cryptoHandler; + this.cryptoContext = cryptoContext; } @Override public long getBlobSize() { - return fileSize; + // initializes the value lazily + if (blobSize == null) { + this.blobSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, super.getBlobSize()); + } + return this.blobSize; } @Override public List<InputStreamContainer> getPartStreams() { - return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList()); + return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList()); } - private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) { + /** + * Transforms an encrypted {@link InputStreamContainer} to a decrypted instance + * @param inputStreamContainer encrypted input stream container instance + * @return decrypted input stream container instance + */ + private InputStreamContainer decryptInputStreamContainer(InputStreamContainer inputStreamContainer) { long startOfStream = inputStreamContainer.getOffset(); long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1; DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange( @@ -202,11 +164,9 @@ private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer in long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0]; long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1; - return new InputStreamContainer( - decryptedStreamProvider.getDecryptedStreamProvider().apply(inputStreamContainer.getInputStream()), - adjustedPos, - adjustedLength - ); + final InputStream decryptedStream = decryptedStreamProvider.getDecryptedStreamProvider() + .apply(inputStreamContainer.getInputStream()); + return new InputStreamContainer(decryptedStream, adjustedLength, adjustedPos); } } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java index 475d891ea9336..d0933741339d9 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java @@ -50,7 +50,7 @@ public InputStream readBlob(String blobName) throws IOException { return cryptoHandler.createDecryptingStream(inputStream); } - private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { + EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { return (start, end) -> { byte[] buffer; int length = (int) (end - start + 1); diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java new file mode 100644 index 0000000000000..947a4f9b1c9ab --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore; + +import org.opensearch.common.Randomness; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils; +import org.opensearch.common.crypto.CryptoHandler; +import org.opensearch.common.crypto.DecryptedRangedStreamProvider; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.List; +import java.util.function.UnaryOperator; + +import org.mockito.Mockito; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AsyncMultiStreamEncryptedBlobContainerTests extends OpenSearchTestCase { + + // Tests the happy path scenario for decrypting a read context + @SuppressWarnings("unchecked") + public void testReadBlobAsync() throws Exception { + String testBlobName = "testBlobName"; + int size = 100; + + // Mock objects needed for the test + AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class); + CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class); + Object cryptoContext = mock(Object.class); + when(cryptoHandler.loadEncryptionMetadata(any())).thenReturn(cryptoContext); + when(cryptoHandler.estimateDecryptedLength(any(), anyLong())).thenReturn((long) size); + long[] adjustedRanges = { 0, size - 1 }; + DecryptedRangedStreamProvider rangedStreamProvider = new DecryptedRangedStreamProvider(adjustedRanges, UnaryOperator.identity()); + when(cryptoHandler.createDecryptingStreamOfRange(eq(cryptoContext), anyLong(), anyLong())).thenReturn(rangedStreamProvider); + + // Objects needed for API call + final byte[] data = new byte[size]; + Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); + final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener = + new ListenerTestUtils.CountingCompletionListener<>(); + final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + + Mockito.doAnswer(invocation -> { + ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1); + readContextActionListener.onResponse(readContext); + return null; + }).when(blobContainer).readBlobAsync(eq(testBlobName), any()); + + AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer = + new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler); + asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener); + + // Assert results + ReadContext response = completionListener.getResponse(); + assertEquals(0, completionListener.getFailureCount()); + assertEquals(1, completionListener.getResponseCount()); + assertNull(completionListener.getException()); + + assertTrue(response instanceof AsyncMultiStreamEncryptedBlobContainer.DecryptedReadContext); + assertEquals(1, response.getNumberOfParts()); + assertEquals(size, response.getBlobSize()); + + InputStreamContainer responseContainer = response.getPartStreams().get(0); + assertEquals(0, responseContainer.getOffset()); + assertEquals(size, responseContainer.getContentLength()); + assertEquals(100, responseContainer.getInputStream().available()); + } + + // Tests the exception scenario for decrypting a read context + @SuppressWarnings("unchecked") + public void testReadBlobAsyncException() throws Exception { + String testBlobName = "testBlobName"; + int size = 100; + + // Mock objects needed for the test + AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class); + CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class); + when(cryptoHandler.loadEncryptionMetadata(any())).thenThrow(new IOException()); + + // Objects needed for API call + final byte[] data = new byte[size]; + Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); + final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener = + new ListenerTestUtils.CountingCompletionListener<>(); + final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + + Mockito.doAnswer(invocation -> { + ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1); + readContextActionListener.onResponse(readContext); + return null; + }).when(blobContainer).readBlobAsync(eq(testBlobName), any()); + + AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer = + new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler); + asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener); + + // Assert results + assertEquals(1, completionListener.getFailureCount()); + assertEquals(0, completionListener.getResponseCount()); + assertNull(completionListener.getResponse()); + assertTrue(completionListener.getException() instanceof IOException); + } + +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java index 1e9450c83e3ab..a3a32f6db2148 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java @@ -19,7 +19,7 @@ public class ListenerTestUtils { * CountingCompletionListener acts as a verification instance for wrapping listener based calls. * Keeps track of the last response, failure and count of response and failure invocations. */ - static class CountingCompletionListener<T> implements ActionListener<T> { + public static class CountingCompletionListener<T> implements ActionListener<T> { private int responseCount; private int failureCount; private T response;