Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for encrypted async blob read #10131

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add metrics for thread_pool task wait time ([#9681](https://github.com/opensearch-project/OpenSearch/pull/9681))
- Async blob read support for S3 plugin ([#9694](https://github.com/opensearch-project/OpenSearch/pull/9694))
- [Telemetry-Otel] Added support for OtlpGrpcSpanExporter exporter ([#9666](https://github.com/opensearch-project/OpenSearch/pull/9666))
- Async blob read support for encrypted containers ([#10131](https://github.com/opensearch-project/OpenSearch/pull/10131))

### Dependencies
- Bump `peter-evans/create-or-update-comment` from 2 to 3 ([#9575](https://github.com/opensearch-project/OpenSearch/pull/9575))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;
import java.io.InputStream;
import java.util.List;
import java.util.stream.Collectors;

/**
* EncryptedBlobContainer is an encrypted BlobContainer that is backed by a
Expand All @@ -44,12 +46,17 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp

@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
throw new UnsupportedOperationException();
}

@Override
public void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
throw new UnsupportedOperationException();
try {
final U cryptoContext = cryptoHandler.loadEncryptionMetadata(getEncryptedHeaderContentSupplier(blobName));
ActionListener<ReadContext> decryptingCompletionListener = ActionListener.map(
listener,
readContext -> new DecryptedReadContext<>(readContext, cryptoHandler, cryptoContext)
);

blobContainer.readBlobAsync(blobName, decryptingCompletionListener);
} catch (Exception e) {
listener.onFailure(e);
}
}

@Override
Expand Down Expand Up @@ -108,4 +115,58 @@ public InputStreamContainer provideStream(int partNumber) throws IOException {
}

}

/**
* DecryptedReadContext decrypts the encrypted {@link ReadContext} by acting as a transformation wrapper around
* the encrypted object
* @param <T> Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
* @param <U> Parsed Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
*/
static class DecryptedReadContext<T, U> extends ReadContext {

private final CryptoHandler<T, U> cryptoHandler;
private final U cryptoContext;
private Long blobSize;

public DecryptedReadContext(ReadContext readContext, CryptoHandler<T, U> cryptoHandler, U cryptoContext) {
super(readContext);
this.cryptoHandler = cryptoHandler;
this.cryptoContext = cryptoContext;
}

@Override
public long getBlobSize() {
// initializes the value lazily
if (blobSize == null) {
this.blobSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, super.getBlobSize());
}
return this.blobSize;
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList());
}

/**
* Transforms an encrypted {@link InputStreamContainer} to a decrypted instance
* @param inputStreamContainer encrypted input stream container instance
* @return decrypted input stream container instance
*/
private InputStreamContainer decryptInputStreamContainer(InputStreamContainer inputStreamContainer) {
long startOfStream = inputStreamContainer.getOffset();
long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1;
DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange(
cryptoContext,
startOfStream,
endOfStream
);

long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0];
long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1;
final InputStream decryptedStream = decryptedStreamProvider.getDecryptedStreamProvider()
.apply(inputStreamContainer.getInputStream());
return new InputStreamContainer(decryptedStream, adjustedLength, adjustedPos);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public InputStream readBlob(String blobName) throws IOException {
return cryptoHandler.createDecryptingStream(inputStream);
}

private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
return (start, end) -> {
byte[] buffer;
int length = (int) (end - start + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.blobChecksum = readContext.blobChecksum;
}

public String getBlobChecksum() {
return blobChecksum;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.blobstore;

import org.opensearch.common.Randomness;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;
import java.util.function.UnaryOperator;

import org.mockito.Mockito;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class AsyncMultiStreamEncryptedBlobContainerTests extends OpenSearchTestCase {

// Tests the happy path scenario for decrypting a read context
@SuppressWarnings("unchecked")
public void testReadBlobAsync() throws Exception {
String testBlobName = "testBlobName";
int size = 100;

// Mock objects needed for the test
AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
Object cryptoContext = mock(Object.class);
when(cryptoHandler.loadEncryptionMetadata(any())).thenReturn(cryptoContext);
when(cryptoHandler.estimateDecryptedLength(any(), anyLong())).thenReturn((long) size);
long[] adjustedRanges = { 0, size - 1 };
DecryptedRangedStreamProvider rangedStreamProvider = new DecryptedRangedStreamProvider(adjustedRanges, UnaryOperator.identity());
when(cryptoHandler.createDecryptingStreamOfRange(eq(cryptoContext), anyLong(), anyLong())).thenReturn(rangedStreamProvider);

// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
readContextActionListener.onResponse(readContext);
return null;
}).when(blobContainer).readBlobAsync(eq(testBlobName), any());

AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);

// Assert results
ReadContext response = completionListener.getResponse();
assertEquals(0, completionListener.getFailureCount());
assertEquals(1, completionListener.getResponseCount());
assertNull(completionListener.getException());

assertTrue(response instanceof AsyncMultiStreamEncryptedBlobContainer.DecryptedReadContext);
assertEquals(1, response.getNumberOfParts());
assertEquals(size, response.getBlobSize());

InputStreamContainer responseContainer = response.getPartStreams().get(0);
assertEquals(0, responseContainer.getOffset());
assertEquals(size, responseContainer.getContentLength());
assertEquals(100, responseContainer.getInputStream().available());
}

// Tests the exception scenario for decrypting a read context
@SuppressWarnings("unchecked")
public void testReadBlobAsyncException() throws Exception {
String testBlobName = "testBlobName";
int size = 100;

// Mock objects needed for the test
AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
when(cryptoHandler.loadEncryptionMetadata(any())).thenThrow(new IOException());

// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
readContextActionListener.onResponse(readContext);
return null;
}).when(blobContainer).readBlobAsync(eq(testBlobName), any());

AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);

// Assert results
assertEquals(1, completionListener.getFailureCount());
assertEquals(0, completionListener.getResponseCount());
assertNull(completionListener.getResponse());
assertTrue(completionListener.getException() instanceof IOException);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class ListenerTestUtils {
* CountingCompletionListener acts as a verification instance for wrapping listener based calls.
* Keeps track of the last response, failure and count of response and failure invocations.
*/
static class CountingCompletionListener<T> implements ActionListener<T> {
public static class CountingCompletionListener<T> implements ActionListener<T> {
private int responseCount;
private int failureCount;
private T response;
Expand Down
Loading