Skip to content

Commit

Permalink
Add support for encrypted async blob read
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Kotwani <[email protected]>
  • Loading branch information
kotwanikunal committed Sep 22, 2023
1 parent cbff21d commit bb22c12
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
import org.opensearch.common.crypto.EncryptedHeaderContentSupplier;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;
import java.io.InputStream;
import java.util.List;
import java.util.stream.Collectors;

/**
* EncryptedBlobContainer is an encrypted BlobContainer that is backed by a
Expand All @@ -44,12 +47,24 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp

@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
throw new UnsupportedOperationException();
DecryptingReadContextListener<T, U> decryptingReadContextListener = new DecryptingReadContextListener<>(
listener,
cryptoHandler,
getEncryptedHeaderContentSupplier(blobName)
);
blobContainer.readBlobAsync(blobName, decryptingReadContextListener);
}

@Override
public void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
throw new UnsupportedOperationException();
private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
return (start, end) -> {
byte[] buffer;
int length = (int) (end - start + 1);
try (InputStream inputStream = blobContainer.readBlob(blobName, start, length)) {
buffer = new byte[length];
inputStream.readNBytes(buffer, (int) start, buffer.length);
}
return buffer;
};
}

@Override
Expand Down Expand Up @@ -108,4 +123,90 @@ public InputStreamContainer provideStream(int partNumber) throws IOException {
}

}

static class DecryptingReadContextListener<T, U> implements ActionListener<ReadContext> {

private final ActionListener<ReadContext> completionListener;
private final CryptoHandler<T, U> cryptoHandler;
private final EncryptedHeaderContentSupplier encryptedHeaderContentSupplier;

public DecryptingReadContextListener(
ActionListener<ReadContext> completionListener,
CryptoHandler<T, U> cryptoHandler,
EncryptedHeaderContentSupplier headerContentSupplier
) {
this.completionListener = completionListener;
this.cryptoHandler = cryptoHandler;
this.encryptedHeaderContentSupplier = headerContentSupplier;
}

@Override
public void onResponse(ReadContext readContext) {
try {
DecryptedReadContext<T, U> decryptedReadContext = new DecryptedReadContext<>(
readContext,
cryptoHandler,
encryptedHeaderContentSupplier
);
completionListener.onResponse(decryptedReadContext);
} catch (Exception e) {
onFailure(e);
}
}

@Override
public void onFailure(Exception e) {
completionListener.onFailure(e);
}
}

static class DecryptedReadContext<T, U> extends ReadContext {

private final U cryptoContext;
private final CryptoHandler<T, U> cryptoHandler;
private final long fileSize;

public DecryptedReadContext(
ReadContext readContext,
CryptoHandler<T, U> cryptoHandler,
EncryptedHeaderContentSupplier headerContentSupplier
) {
super(readContext);
try {
this.cryptoHandler = cryptoHandler;
this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(headerContentSupplier);
this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize());
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public long getBlobSize() {
return fileSize;
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList());
}

private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) {
long startOfStream = inputStreamContainer.getOffset();
long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1;
DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange(
cryptoContext,
startOfStream,
endOfStream
);

long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0];
long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1;
return new InputStreamContainer(
decryptedStreamProvider.getDecryptedStreamProvider().apply(inputStreamContainer.getInputStream()),
adjustedPos,
adjustedLength
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.blobChecksum = readContext.blobChecksum;
}

public String getBlobChecksum() {
return blobChecksum;
}
Expand Down

0 comments on commit bb22c12

Please sign in to comment.