From bb22c12d3416e9f70ffadec00b10cc86de21231e Mon Sep 17 00:00:00 2001
From: Kunal Kotwani <kkotwani@amazon.com>
Date: Mon, 11 Sep 2023 20:11:25 -0700
Subject: [PATCH 1/2] Add support for encrypted async blob read

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
---
 ...syncMultiStreamEncryptedBlobContainer.java | 113 +++++++++++++++++-
 .../blobstore/stream/read/ReadContext.java    |   6 +
 2 files changed, 113 insertions(+), 6 deletions(-)

diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java
index 07a0b49df47ff..9021ced7d9af6 100644
--- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java
+++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java
@@ -12,12 +12,15 @@
 import org.opensearch.common.blobstore.stream.read.ReadContext;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.crypto.CryptoHandler;
+import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
+import org.opensearch.common.crypto.EncryptedHeaderContentSupplier;
 import org.opensearch.common.io.InputStreamContainer;
 import org.opensearch.core.action.ActionListener;
-import org.opensearch.threadpool.ThreadPool;
 
 import java.io.IOException;
-import java.nio.file.Path;
+import java.io.InputStream;
+import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * EncryptedBlobContainer is an encrypted BlobContainer that is backed by a
@@ -44,12 +47,24 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
 
     @Override
     public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
-        throw new UnsupportedOperationException();
+        DecryptingReadContextListener<T, U> decryptingReadContextListener = new DecryptingReadContextListener<>(
+            listener,
+            cryptoHandler,
+            getEncryptedHeaderContentSupplier(blobName)
+        );
+        blobContainer.readBlobAsync(blobName, decryptingReadContextListener);
     }
 
-    @Override
-    public void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
-        throw new UnsupportedOperationException();
+    private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
+        return (start, end) -> {
+            byte[] buffer;
+            int length = (int) (end - start + 1);
+            try (InputStream inputStream = blobContainer.readBlob(blobName, start, length)) {
+                buffer = new byte[length];
+                inputStream.readNBytes(buffer, (int) start, buffer.length);
+            }
+            return buffer;
+        };
     }
 
     @Override
@@ -108,4 +123,90 @@ public InputStreamContainer provideStream(int partNumber) throws IOException {
         }
 
     }
+
+    static class DecryptingReadContextListener<T, U> implements ActionListener<ReadContext> {
+
+        private final ActionListener<ReadContext> completionListener;
+        private final CryptoHandler<T, U> cryptoHandler;
+        private final EncryptedHeaderContentSupplier encryptedHeaderContentSupplier;
+
+        public DecryptingReadContextListener(
+            ActionListener<ReadContext> completionListener,
+            CryptoHandler<T, U> cryptoHandler,
+            EncryptedHeaderContentSupplier headerContentSupplier
+        ) {
+            this.completionListener = completionListener;
+            this.cryptoHandler = cryptoHandler;
+            this.encryptedHeaderContentSupplier = headerContentSupplier;
+        }
+
+        @Override
+        public void onResponse(ReadContext readContext) {
+            try {
+                DecryptedReadContext<T, U> decryptedReadContext = new DecryptedReadContext<>(
+                    readContext,
+                    cryptoHandler,
+                    encryptedHeaderContentSupplier
+                );
+                completionListener.onResponse(decryptedReadContext);
+            } catch (Exception e) {
+                onFailure(e);
+            }
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            completionListener.onFailure(e);
+        }
+    }
+
+    static class DecryptedReadContext<T, U> extends ReadContext {
+
+        private final U cryptoContext;
+        private final CryptoHandler<T, U> cryptoHandler;
+        private final long fileSize;
+
+        public DecryptedReadContext(
+            ReadContext readContext,
+            CryptoHandler<T, U> cryptoHandler,
+            EncryptedHeaderContentSupplier headerContentSupplier
+        ) {
+            super(readContext);
+            try {
+                this.cryptoHandler = cryptoHandler;
+                this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(headerContentSupplier);
+                this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize());
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        @Override
+        public long getBlobSize() {
+            return fileSize;
+        }
+
+        @Override
+        public List<InputStreamContainer> getPartStreams() {
+            return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList());
+        }
+
+        private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) {
+            long startOfStream = inputStreamContainer.getOffset();
+            long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1;
+            DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange(
+                cryptoContext,
+                startOfStream,
+                endOfStream
+            );
+
+            long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0];
+            long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1;
+            return new InputStreamContainer(
+                decryptedStreamProvider.getDecryptedStreamProvider().apply(inputStreamContainer.getInputStream()),
+                adjustedPos,
+                adjustedLength
+            );
+        }
+    }
 }
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java
index 4ba17959f8040..2c305fb03c475 100644
--- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java
@@ -28,6 +28,12 @@ public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String
         this.blobChecksum = blobChecksum;
     }
 
+    public ReadContext(ReadContext readContext) {
+        this.blobSize = readContext.blobSize;
+        this.partStreams = readContext.partStreams;
+        this.blobChecksum = readContext.blobChecksum;
+    }
+
     public String getBlobChecksum() {
         return blobChecksum;
     }

From 5551aafffb25db7b838db1963a31501070d79864 Mon Sep 17 00:00:00 2001
From: Kunal Kotwani <kkotwani@amazon.com>
Date: Tue, 19 Sep 2023 16:59:28 -0700
Subject: [PATCH 2/2] Add async blob read support for encrypted containers

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
---
 CHANGELOG.md                                  |   1 +
 ...syncMultiStreamEncryptedBlobContainer.java | 112 ++++++----------
 .../blobstore/EncryptedBlobContainer.java     |   2 +-
 ...ultiStreamEncryptedBlobContainerTests.java | 121 ++++++++++++++++++
 .../read/listener/ListenerTestUtils.java      |   2 +-
 5 files changed, 160 insertions(+), 78 deletions(-)
 create mode 100644 server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 36d566805ebf4..44db1a5512840 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -84,6 +84,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Add metrics for thread_pool task wait time ([#9681](https://github.com/opensearch-project/OpenSearch/pull/9681))
 - Async blob read support for S3 plugin ([#9694](https://github.com/opensearch-project/OpenSearch/pull/9694))
 - [Telemetry-Otel] Added support for OtlpGrpcSpanExporter exporter ([#9666](https://github.com/opensearch-project/OpenSearch/pull/9666))
+- Async blob read support for encrypted containers ([#10131](https://github.com/opensearch-project/OpenSearch/pull/10131))
 
 ### Dependencies
 - Bump `peter-evans/create-or-update-comment` from 2 to 3 ([#9575](https://github.com/opensearch-project/OpenSearch/pull/9575))
diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java
index 9021ced7d9af6..c64dc6b9e3ae4 100644
--- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java
+++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java
@@ -13,7 +13,6 @@
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.crypto.CryptoHandler;
 import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
-import org.opensearch.common.crypto.EncryptedHeaderContentSupplier;
 import org.opensearch.common.io.InputStreamContainer;
 import org.opensearch.core.action.ActionListener;
 
@@ -47,24 +46,17 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
 
     @Override
     public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
-        DecryptingReadContextListener<T, U> decryptingReadContextListener = new DecryptingReadContextListener<>(
-            listener,
-            cryptoHandler,
-            getEncryptedHeaderContentSupplier(blobName)
-        );
-        blobContainer.readBlobAsync(blobName, decryptingReadContextListener);
-    }
+        try {
+            final U cryptoContext = cryptoHandler.loadEncryptionMetadata(getEncryptedHeaderContentSupplier(blobName));
+            ActionListener<ReadContext> decryptingCompletionListener = ActionListener.map(
+                listener,
+                readContext -> new DecryptedReadContext<>(readContext, cryptoHandler, cryptoContext)
+            );
 
-    private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
-        return (start, end) -> {
-            byte[] buffer;
-            int length = (int) (end - start + 1);
-            try (InputStream inputStream = blobContainer.readBlob(blobName, start, length)) {
-                buffer = new byte[length];
-                inputStream.readNBytes(buffer, (int) start, buffer.length);
-            }
-            return buffer;
-        };
+            blobContainer.readBlobAsync(blobName, decryptingCompletionListener);
+        } catch (Exception e) {
+            listener.onFailure(e);
+        }
     }
 
     @Override
@@ -124,74 +116,44 @@ public InputStreamContainer provideStream(int partNumber) throws IOException {
 
     }
 
-    static class DecryptingReadContextListener<T, U> implements ActionListener<ReadContext> {
-
-        private final ActionListener<ReadContext> completionListener;
-        private final CryptoHandler<T, U> cryptoHandler;
-        private final EncryptedHeaderContentSupplier encryptedHeaderContentSupplier;
-
-        public DecryptingReadContextListener(
-            ActionListener<ReadContext> completionListener,
-            CryptoHandler<T, U> cryptoHandler,
-            EncryptedHeaderContentSupplier headerContentSupplier
-        ) {
-            this.completionListener = completionListener;
-            this.cryptoHandler = cryptoHandler;
-            this.encryptedHeaderContentSupplier = headerContentSupplier;
-        }
-
-        @Override
-        public void onResponse(ReadContext readContext) {
-            try {
-                DecryptedReadContext<T, U> decryptedReadContext = new DecryptedReadContext<>(
-                    readContext,
-                    cryptoHandler,
-                    encryptedHeaderContentSupplier
-                );
-                completionListener.onResponse(decryptedReadContext);
-            } catch (Exception e) {
-                onFailure(e);
-            }
-        }
-
-        @Override
-        public void onFailure(Exception e) {
-            completionListener.onFailure(e);
-        }
-    }
-
+    /**
+     * DecryptedReadContext decrypts the encrypted {@link ReadContext} by acting as a transformation wrapper around
+     * the encrypted object
+     * @param <T> Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
+     * @param <U> Parsed Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
+     */
     static class DecryptedReadContext<T, U> extends ReadContext {
 
-        private final U cryptoContext;
         private final CryptoHandler<T, U> cryptoHandler;
-        private final long fileSize;
+        private final U cryptoContext;
+        private Long blobSize;
 
-        public DecryptedReadContext(
-            ReadContext readContext,
-            CryptoHandler<T, U> cryptoHandler,
-            EncryptedHeaderContentSupplier headerContentSupplier
-        ) {
+        public DecryptedReadContext(ReadContext readContext, CryptoHandler<T, U> cryptoHandler, U cryptoContext) {
             super(readContext);
-            try {
-                this.cryptoHandler = cryptoHandler;
-                this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(headerContentSupplier);
-                this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize());
-            } catch (IOException e) {
-                throw new RuntimeException(e);
-            }
+            this.cryptoHandler = cryptoHandler;
+            this.cryptoContext = cryptoContext;
         }
 
         @Override
         public long getBlobSize() {
-            return fileSize;
+            // initializes the value lazily
+            if (blobSize == null) {
+                this.blobSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, super.getBlobSize());
+            }
+            return this.blobSize;
         }
 
         @Override
         public List<InputStreamContainer> getPartStreams() {
-            return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList());
+            return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList());
         }
 
-        private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) {
+        /**
+         * Transforms an encrypted {@link InputStreamContainer} to a decrypted instance
+         * @param inputStreamContainer encrypted input stream container instance
+         * @return decrypted input stream container instance
+         */
+        private InputStreamContainer decryptInputStreamContainer(InputStreamContainer inputStreamContainer) {
             long startOfStream = inputStreamContainer.getOffset();
             long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1;
             DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange(
@@ -202,11 +164,9 @@ private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer in
 
             long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0];
             long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1;
-            return new InputStreamContainer(
-                decryptedStreamProvider.getDecryptedStreamProvider().apply(inputStreamContainer.getInputStream()),
-                adjustedPos,
-                adjustedLength
-            );
+            final InputStream decryptedStream = decryptedStreamProvider.getDecryptedStreamProvider()
+                .apply(inputStreamContainer.getInputStream());
+            return new InputStreamContainer(decryptedStream, adjustedLength, adjustedPos);
         }
     }
 }
diff --git a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java
index 475d891ea9336..d0933741339d9 100644
--- a/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java
+++ b/server/src/main/java/org/opensearch/common/blobstore/EncryptedBlobContainer.java
@@ -50,7 +50,7 @@ public InputStream readBlob(String blobName) throws IOException {
         return cryptoHandler.createDecryptingStream(inputStream);
     }
 
-    private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
+    EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
         return (start, end) -> {
             byte[] buffer;
             int length = (int) (end - start + 1);
diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java
new file mode 100644
index 0000000000000..947a4f9b1c9ab
--- /dev/null
+++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java
@@ -0,0 +1,121 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore;
+
+import org.opensearch.common.Randomness;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
+import org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils;
+import org.opensearch.common.crypto.CryptoHandler;
+import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
+import org.opensearch.common.io.InputStreamContainer;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.List;
+import java.util.function.UnaryOperator;
+
+import org.mockito.Mockito;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AsyncMultiStreamEncryptedBlobContainerTests extends OpenSearchTestCase {
+
+    // Tests the happy path scenario for decrypting a read context
+    @SuppressWarnings("unchecked")
+    public void testReadBlobAsync() throws Exception {
+        String testBlobName = "testBlobName";
+        int size = 100;
+
+        // Mock objects needed for the test
+        AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
+        CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
+        Object cryptoContext = mock(Object.class);
+        when(cryptoHandler.loadEncryptionMetadata(any())).thenReturn(cryptoContext);
+        when(cryptoHandler.estimateDecryptedLength(any(), anyLong())).thenReturn((long) size);
+        long[] adjustedRanges = { 0, size - 1 };
+        DecryptedRangedStreamProvider rangedStreamProvider = new DecryptedRangedStreamProvider(adjustedRanges, UnaryOperator.identity());
+        when(cryptoHandler.createDecryptingStreamOfRange(eq(cryptoContext), anyLong(), anyLong())).thenReturn(rangedStreamProvider);
+
+        // Objects needed for API call
+        final byte[] data = new byte[size];
+        Randomness.get().nextBytes(data);
+        final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
+        final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
+            new ListenerTestUtils.CountingCompletionListener<>();
+        final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);
+
+        Mockito.doAnswer(invocation -> {
+            ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
+            readContextActionListener.onResponse(readContext);
+            return null;
+        }).when(blobContainer).readBlobAsync(eq(testBlobName), any());
+
+        AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
+            new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
+        asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);
+
+        // Assert results
+        ReadContext response = completionListener.getResponse();
+        assertEquals(0, completionListener.getFailureCount());
+        assertEquals(1, completionListener.getResponseCount());
+        assertNull(completionListener.getException());
+
+        assertTrue(response instanceof AsyncMultiStreamEncryptedBlobContainer.DecryptedReadContext);
+        assertEquals(1, response.getNumberOfParts());
+        assertEquals(size, response.getBlobSize());
+
+        InputStreamContainer responseContainer = response.getPartStreams().get(0);
+        assertEquals(0, responseContainer.getOffset());
+        assertEquals(size, responseContainer.getContentLength());
+        assertEquals(100, responseContainer.getInputStream().available());
+    }
+
+    // Tests the exception scenario for decrypting a read context
+    @SuppressWarnings("unchecked")
+    public void testReadBlobAsyncException() throws Exception {
+        String testBlobName = "testBlobName";
+        int size = 100;
+
+        // Mock objects needed for the test
+        AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
+        CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
+        when(cryptoHandler.loadEncryptionMetadata(any())).thenThrow(new IOException());
+
+        // Objects needed for API call
+        final byte[] data = new byte[size];
+        Randomness.get().nextBytes(data);
+        final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
+        final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
+            new ListenerTestUtils.CountingCompletionListener<>();
+        final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);
+
+        Mockito.doAnswer(invocation -> {
+            ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
+            readContextActionListener.onResponse(readContext);
+            return null;
+        }).when(blobContainer).readBlobAsync(eq(testBlobName), any());
+
+        AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
+            new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
+        asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);
+
+        // Assert results
+        assertEquals(1, completionListener.getFailureCount());
+        assertEquals(0, completionListener.getResponseCount());
+        assertNull(completionListener.getResponse());
+        assertTrue(completionListener.getException() instanceof IOException);
+    }
+
+}
diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java
index 1e9450c83e3ab..a3a32f6db2148 100644
--- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java
+++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java
@@ -19,7 +19,7 @@ public class ListenerTestUtils {
      * CountingCompletionListener acts as a verification instance for wrapping listener based calls.
      * Keeps track of the last response, failure and count of response and failure invocations.
      */
-    static class CountingCompletionListener<T> implements ActionListener<T> {
+    public static class CountingCompletionListener<T> implements ActionListener<T> {
         private int responseCount;
         private int failureCount;
         private T response;