From 689adc157442c9482e8960ab6d76eb3ea85ac3e6 Mon Sep 17 00:00:00 2001 From: Kunal Kotwani Date: Tue, 3 Oct 2023 22:30:52 -0700 Subject: [PATCH] Add support for encrypted async blob read (#10131) (#10346) * Add support for encrypted async blob read Signed-off-by: Kunal Kotwani * Add async blob read support for encrypted containers Signed-off-by: Kunal Kotwani --------- Signed-off-by: Kunal Kotwani (cherry picked from commit c4c4ad84d995f75f8749a0f99aa8a1ecc3b71760) --- CHANGELOG.md | 1 + ...syncMultiStreamEncryptedBlobContainer.java | 77 +++++++++-- .../blobstore/EncryptedBlobContainer.java | 2 +- .../blobstore/stream/read/ReadContext.java | 6 + ...ultiStreamEncryptedBlobContainerTests.java | 121 ++++++++++++++++++ .../read/listener/ListenerTestUtils.java | 2 +- 6 files changed, 199 insertions(+), 10 deletions(-) create mode 100644 server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index b59ddb5227f9e..b99f588fc58c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add metrics for thread_pool task wait time ([#9681](https://github.com/opensearch-project/OpenSearch/pull/9681)) - Async blob read support for S3 plugin ([#9694](https://github.com/opensearch-project/OpenSearch/pull/9694)) - [Telemetry-Otel] Added support for OtlpGrpcSpanExporter exporter ([#9666](https://github.com/opensearch-project/OpenSearch/pull/9666)) +- Async blob read support for encrypted containers ([#10131](https://github.com/opensearch-project/OpenSearch/pull/10131)) - Implement Visitor Design pattern in QueryBuilder to enable the capability to traverse through the complex QueryBuilder tree. ([#10110](https://github.com/opensearch-project/OpenSearch/pull/10110)) - Add capability to restrict async durability mode for remote indexes ([#10189](https://github.com/opensearch-project/OpenSearch/pull/10189)) - Add Doc Status Counter for Indexing Engine ([#4562](https://github.com/opensearch-project/OpenSearch/issues/4562)) diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index 07a0b49df47ff..c64dc6b9e3ae4 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -12,12 +12,14 @@ import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.crypto.CryptoHandler; +import org.opensearch.common.crypto.DecryptedRangedStreamProvider; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.nio.file.Path; +import java.io.InputStream; +import java.util.List; +import java.util.stream.Collectors; /** * EncryptedBlobContainer is an encrypted BlobContainer that is backed by a @@ -44,12 +46,17 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp @Override public void readBlobAsync(String blobName, ActionListener listener) { - throw new UnsupportedOperationException(); - } - - @Override - public void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { - throw new UnsupportedOperationException(); + try { + final U cryptoContext = cryptoHandler.loadEncryptionMetadata(getEncryptedHeaderContentSupplier(blobName)); + ActionListener decryptingCompletionListener = ActionListener.map( + listener, + readContext -> new DecryptedReadContext<>(readContext, cryptoHandler, cryptoContext) + ); + + blobContainer.readBlobAsync(blobName, decryptingCompletionListener); + } catch (Exception e) { + listener.onFailure(e); + } } @Override @@ -108,4 +115,58 @@ public InputStreamContainer provideStream(int partNumber) throws IOException { } } + + /** + * DecryptedReadContext decrypts the encrypted {@link ReadContext} by acting as a transformation wrapper around + * the encrypted object + * @param Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance + * @param Parsed Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance + */ + static class DecryptedReadContext extends ReadContext { + + private final CryptoHandler cryptoHandler; + private final U cryptoContext; + private Long blobSize; + + public DecryptedReadContext(ReadContext readContext, CryptoHandler cryptoHandler, U cryptoContext) { + super(readContext); + this.cryptoHandler = cryptoHandler; + this.cryptoContext = cryptoContext; + } + + @Override + public long getBlobSize() { + // initializes the value lazily + if (blobSize == null) { + this.blobSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, super.getBlobSize()); + } + return this.blobSize; + } + + @Override + public List getPartStreams() { + return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList()); + } + + /** + * Transforms an encrypted {@link InputStreamContainer} to a decrypted instance + * @param inputStreamContainer encrypted input stream container instance + * @return decrypted input stream container instance + */ + private InputStreamContainer decryptInputStreamContainer(InputStreamContainer inputStreamContainer) { + long startOfStream = inputStreamContainer.getOffset(); + long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1; + DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange( + cryptoContext, + startOfStream, + endOfStream + ); + + long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0]; + long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1; + final InputStream decryptedStream = decryptedStreamProvider.getDecryptedStreamProvider() + .apply(inputStreamContainer.getInputStream()); + return new InputStreamContainer(decryptedStream, adjustedLength, adjustedPos); + } + } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java index 475d891ea9336..d0933741339d9 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java @@ -50,7 +50,7 @@ public InputStream readBlob(String blobName) throws IOException { return cryptoHandler.createDecryptingStream(inputStream); } - private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { + EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) { return (start, end) -> { byte[] buffer; int length = (int) (end - start + 1); diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java index 4ba17959f8040..2c305fb03c475 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java @@ -28,6 +28,12 @@ public ReadContext(long blobSize, List partStreams, String this.blobChecksum = blobChecksum; } + public ReadContext(ReadContext readContext) { + this.blobSize = readContext.blobSize; + this.partStreams = readContext.partStreams; + this.blobChecksum = readContext.blobChecksum; + } + public String getBlobChecksum() { return blobChecksum; } diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java new file mode 100644 index 0000000000000..947a4f9b1c9ab --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore; + +import org.opensearch.common.Randomness; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils; +import org.opensearch.common.crypto.CryptoHandler; +import org.opensearch.common.crypto.DecryptedRangedStreamProvider; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.List; +import java.util.function.UnaryOperator; + +import org.mockito.Mockito; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AsyncMultiStreamEncryptedBlobContainerTests extends OpenSearchTestCase { + + // Tests the happy path scenario for decrypting a read context + @SuppressWarnings("unchecked") + public void testReadBlobAsync() throws Exception { + String testBlobName = "testBlobName"; + int size = 100; + + // Mock objects needed for the test + AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class); + CryptoHandler cryptoHandler = mock(CryptoHandler.class); + Object cryptoContext = mock(Object.class); + when(cryptoHandler.loadEncryptionMetadata(any())).thenReturn(cryptoContext); + when(cryptoHandler.estimateDecryptedLength(any(), anyLong())).thenReturn((long) size); + long[] adjustedRanges = { 0, size - 1 }; + DecryptedRangedStreamProvider rangedStreamProvider = new DecryptedRangedStreamProvider(adjustedRanges, UnaryOperator.identity()); + when(cryptoHandler.createDecryptingStreamOfRange(eq(cryptoContext), anyLong(), anyLong())).thenReturn(rangedStreamProvider); + + // Objects needed for API call + final byte[] data = new byte[size]; + Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); + final ListenerTestUtils.CountingCompletionListener completionListener = + new ListenerTestUtils.CountingCompletionListener<>(); + final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + + Mockito.doAnswer(invocation -> { + ActionListener readContextActionListener = invocation.getArgument(1); + readContextActionListener.onResponse(readContext); + return null; + }).when(blobContainer).readBlobAsync(eq(testBlobName), any()); + + AsyncMultiStreamEncryptedBlobContainer asyncMultiStreamEncryptedBlobContainer = + new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler); + asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener); + + // Assert results + ReadContext response = completionListener.getResponse(); + assertEquals(0, completionListener.getFailureCount()); + assertEquals(1, completionListener.getResponseCount()); + assertNull(completionListener.getException()); + + assertTrue(response instanceof AsyncMultiStreamEncryptedBlobContainer.DecryptedReadContext); + assertEquals(1, response.getNumberOfParts()); + assertEquals(size, response.getBlobSize()); + + InputStreamContainer responseContainer = response.getPartStreams().get(0); + assertEquals(0, responseContainer.getOffset()); + assertEquals(size, responseContainer.getContentLength()); + assertEquals(100, responseContainer.getInputStream().available()); + } + + // Tests the exception scenario for decrypting a read context + @SuppressWarnings("unchecked") + public void testReadBlobAsyncException() throws Exception { + String testBlobName = "testBlobName"; + int size = 100; + + // Mock objects needed for the test + AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class); + CryptoHandler cryptoHandler = mock(CryptoHandler.class); + when(cryptoHandler.loadEncryptionMetadata(any())).thenThrow(new IOException()); + + // Objects needed for API call + final byte[] data = new byte[size]; + Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); + final ListenerTestUtils.CountingCompletionListener completionListener = + new ListenerTestUtils.CountingCompletionListener<>(); + final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + + Mockito.doAnswer(invocation -> { + ActionListener readContextActionListener = invocation.getArgument(1); + readContextActionListener.onResponse(readContext); + return null; + }).when(blobContainer).readBlobAsync(eq(testBlobName), any()); + + AsyncMultiStreamEncryptedBlobContainer asyncMultiStreamEncryptedBlobContainer = + new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler); + asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener); + + // Assert results + assertEquals(1, completionListener.getFailureCount()); + assertEquals(0, completionListener.getResponseCount()); + assertNull(completionListener.getResponse()); + assertTrue(completionListener.getException() instanceof IOException); + } + +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java index 1e9450c83e3ab..a3a32f6db2148 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java @@ -19,7 +19,7 @@ public class ListenerTestUtils { * CountingCompletionListener acts as a verification instance for wrapping listener based calls. * Keeps track of the last response, failure and count of response and failure invocations. */ - static class CountingCompletionListener implements ActionListener { + public static class CountingCompletionListener implements ActionListener { private int responseCount; private int failureCount; private T response;