Skip to content

Commit

Permalink
Allow customer key encryption type for S3
Browse files Browse the repository at this point in the history
  • Loading branch information
marcinsbd authored and wendigo committed Nov 19, 2024
1 parent d9cb746 commit 7aae3f7
Show file tree
Hide file tree
Showing 10 changed files with 335 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.filesystem.s3;

import io.trino.filesystem.s3.S3FileSystemConfig.ObjectCannedAcl;
import io.trino.filesystem.s3.S3FileSystemConfig.S3SseType;
import io.trino.spi.security.ConnectorIdentity;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
Expand Down Expand Up @@ -85,18 +86,28 @@ public void applyCredentialProviderOverride(AwsRequestOverrideConfiguration.Buil
credentialsProviderOverride.ifPresent(builder::credentialsProvider);
}

record S3SseContext(S3FileSystemConfig.S3SseType sseType, Optional<String> sseKmsKeyId)
record S3SseContext(S3SseType sseType, Optional<String> sseKmsKeyId, Optional<S3SseCustomerKey> sseCustomerKey)
{
public S3SseContext
S3SseContext
{
requireNonNull(sseType, "sseType is null");
requireNonNull(sseKmsKeyId, "sseKmsKeyId is null");
checkArgument((sseType != KMS) || (sseKmsKeyId.isPresent()), "sseKmsKeyId is missing for SSE-KMS");
requireNonNull(sseCustomerKey, "sseCustomerKey is null");
switch (sseType) {
case KMS -> checkArgument(sseKmsKeyId.isPresent(), "sseKmsKeyId is missing for SSE-KMS");
case CUSTOMER -> checkArgument(sseCustomerKey.isPresent(), "sseCustomerKey is missing for SSE-C");
case NONE, S3 -> {}
}
}

public static S3SseContext of(S3SseType sseType, String sseKmsKeyId, String sseCustomerKey)
{
return new S3SseContext(sseType, Optional.ofNullable(sseKmsKeyId), Optional.ofNullable(sseCustomerKey).map(S3SseCustomerKey::onAes256));
}

public static S3SseContext withKmsKeyId(String kmsKeyId)
{
return new S3SseContext(KMS, Optional.ofNullable(kmsKeyId));
return new S3SseContext(KMS, Optional.ofNullable(kmsKeyId), Optional.empty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@
import java.util.concurrent.Executor;
import java.util.stream.Stream;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.partition;
import static com.google.common.collect.Multimaps.toMultimap;
import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE;
import static io.trino.filesystem.s3.S3SseCUtils.encoded;
import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum;
import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toMap;

Expand Down Expand Up @@ -341,16 +344,20 @@ public Optional<UriLocation> encryptedPreSignedUri(Location location, Duration t
location.verifyValidFileLocation();
S3Location s3Location = new S3Location(location);

verify(key.isEmpty() || context.s3SseContext().sseType() == NONE, "Encryption key cannot be used with SSE configuration");

GetObjectRequest request = GetObjectRequest.builder()
.overrideConfiguration(context::applyCredentialProviderOverride)
.requestPayer(requestPayer)
.key(s3Location.key())
.bucket(s3Location.bucket())
.applyMutation(builder -> key.ifPresent(encryption -> {
builder.sseCustomerKeyMD5(md5Checksum(encryption));
builder.sseCustomerAlgorithm(encryption.algorithm());
builder.sseCustomerKey(encoded(encryption));
}))
.applyMutation(builder ->
key.ifPresentOrElse(
encryption ->
builder.sseCustomerKeyMD5(md5Checksum(encryption))
.sseCustomerAlgorithm(encryption.algorithm())
.sseCustomerKey(encoded(encryption)),
() -> setEncryptionSettings(builder, context.s3SseContext())))
.build();

GetObjectPresignRequest preSignRequest = GetObjectPresignRequest.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.airlift.units.Duration;
import io.airlift.units.MaxDataSize;
import io.airlift.units.MinDataSize;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import software.amazon.awssdk.retries.api.RetryStrategy;
Expand All @@ -41,7 +42,7 @@ public class S3FileSystemConfig
{
public enum S3SseType
{
NONE, S3, KMS
NONE, S3, KMS, CUSTOMER
}

public enum ObjectCannedAcl
Expand Down Expand Up @@ -96,6 +97,7 @@ public static RetryStrategy getRetryStrategy(RetryMode retryMode)
private String stsRegion;
private S3SseType sseType = S3SseType.NONE;
private String sseKmsKeyId;
private String sseCustomerKey;
private boolean useWebIdentityTokenCredentialsProvider;
private DataSize streamingPartSize = DataSize.of(16, MEGABYTE);
private boolean requesterPays;
Expand Down Expand Up @@ -320,6 +322,29 @@ public S3FileSystemConfig setUseWebIdentityTokenCredentialsProvider(boolean useW
return this;
}

public String getSseCustomerKey()
{
return sseCustomerKey;
}

@Config("s3.sse.customer-key")
@ConfigDescription("Customer Key to use for S3 server-side encryption with Customer key (SSE-C)")
@ConfigSecuritySensitive
public S3FileSystemConfig setSseCustomerKey(String sseCustomerKey)
{
this.sseCustomerKey = sseCustomerKey;
return this;
}

@AssertTrue(message = "s3.sse.customer-key has to be set for server-side encryption with customer-provided key")
public boolean isSseWithCustomerKeyConfigValid()
{
if (sseType == S3SseType.CUSTOMER) {
return sseCustomerKey != null;
}
return true;
}

@NotNull
@MinDataSize("5MB")
@MaxDataSize("256MB")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ private S3FileSystemLoader(Optional<S3SecurityMappingProvider> mappingProvider,
this.context = new S3Context(
toIntExact(config.getStreamingPartSize().toBytes()),
config.isRequesterPays(),
new S3SseContext(
S3SseContext.of(
config.getSseType(),
Optional.ofNullable(config.getSseKmsKeyId())),
config.getSseKmsKeyId(),
config.getSseCustomerKey()),
Optional.empty(),
config.getCannedAcl(),
config.isSupportsExclusiveCreate());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
import java.time.Instant;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE;
import static io.trino.filesystem.s3.S3SseCUtils.encoded;
import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum;
import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings;
import static java.util.Objects.requireNonNull;

final class S3InputFile
Expand All @@ -57,6 +60,8 @@ public S3InputFile(S3Client client, S3Context context, S3Location location, Long
this.lastModified = lastModified;
this.key = requireNonNull(key, "key is null");
location.location().verifyValidFileLocation();

verify(key.isEmpty() || context.s3SseContext().sseType() == NONE, "Encryption key cannot be used with SSE configuration");
}

@Override
Expand Down Expand Up @@ -111,11 +116,13 @@ private GetObjectRequest newGetObjectRequest()
.requestPayer(requestPayer)
.bucket(location.bucket())
.key(location.key())
.applyMutation(builder -> key.ifPresent(encryption -> {
builder.sseCustomerKey(encoded(encryption));
builder.sseCustomerAlgorithm(encryption.algorithm());
builder.sseCustomerKeyMD5(md5Checksum(encryption));
}))
.applyMutation(builder ->
key.ifPresentOrElse(
encryption ->
builder.sseCustomerKey(encoded(encryption))
.sseCustomerAlgorithm(encryption.algorithm())
.sseCustomerKeyMD5(md5Checksum(encryption)),
() -> setEncryptionSettings(builder, context.s3SseContext())))
.build();
}

Expand All @@ -127,11 +134,13 @@ private boolean headObject()
.requestPayer(requestPayer)
.bucket(location.bucket())
.key(location.key())
.applyMutation(builder -> key.ifPresent(encryption -> {
builder.sseCustomerKey(encoded(encryption));
builder.sseCustomerAlgorithm(encryption.algorithm());
builder.sseCustomerKeyMD5(md5Checksum(encryption));
}))
.applyMutation(builder ->
key.ifPresentOrElse(
encryption ->
builder.sseCustomerKey(encoded(encryption))
.sseCustomerAlgorithm(encryption.algorithm())
.sseCustomerKeyMD5(md5Checksum(encryption)),
() -> setEncryptionSettings(builder, context.s3SseContext())))
.build();

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE;
import static io.trino.filesystem.s3.S3SseCUtils.encoded;
import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum;
import static io.trino.filesystem.s3.S3SseRequestConfigurator.addEncryptionSettings;
import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings;
import static java.lang.Math.clamp;
import static java.lang.Math.max;
import static java.lang.Math.min;
Expand Down Expand Up @@ -224,7 +224,7 @@ private void flushBuffer(boolean finished)
builder.sseCustomerAlgorithm(encryption.algorithm());
builder.sseCustomerKeyMD5(md5Checksum(encryption));
});
addEncryptionSettings(builder, context.s3SseContext());
setEncryptionSettings(builder, context.s3SseContext());
})
.build();

Expand Down Expand Up @@ -304,14 +304,13 @@ private CompletedPart uploadPage(byte[] data, int length)
.requestPayer(requestPayer)
.bucket(location.bucket())
.key(location.key())
.applyMutation(builder -> {
key.ifPresent(encryption -> {
builder.sseCustomerKey(encoded(encryption));
builder.sseCustomerAlgorithm(encryption.algorithm());
builder.sseCustomerKeyMD5(md5Checksum(encryption));
});
addEncryptionSettings(builder, context.s3SseContext());
})
.applyMutation(builder ->
key.ifPresentOrElse(
encryption ->
builder.sseCustomerKey(encoded(encryption))
.sseCustomerAlgorithm(encryption.algorithm())
.sseCustomerKeyMD5(md5Checksum(encryption)),
() -> setEncryptionSettings(builder, context.s3SseContext())))
.build();

uploadId = Optional.of(client.createMultipartUpload(request).uploadId());
Expand All @@ -326,11 +325,13 @@ private CompletedPart uploadPage(byte[] data, int length)
.contentLength((long) length)
.uploadId(uploadId.get())
.partNumber(currentPartNumber)
.applyMutation(builder -> key.ifPresent(encryption -> {
builder.sseCustomerKey(encoded(encryption));
builder.sseCustomerAlgorithm(encryption.algorithm());
builder.sseCustomerKeyMD5(md5Checksum(encryption));
}))
.applyMutation(builder ->
key.ifPresentOrElse(
encryption ->
builder.sseCustomerKey(encoded(encryption))
.sseCustomerAlgorithm(encryption.algorithm())
.sseCustomerKeyMD5(md5Checksum(encryption)),
() -> setEncryptionSettings(builder, context.s3SseContext())))
.build();

ByteBuffer bytes = ByteBuffer.wrap(data, 0, length);
Expand All @@ -356,11 +357,12 @@ private void finishUpload(String uploadId)
.uploadId(uploadId)
.multipartUpload(x -> x.parts(parts))
.applyMutation(builder -> {
key.ifPresent(encodingKey -> {
builder.sseCustomerKey(encoded(encodingKey));
builder.sseCustomerAlgorithm(encodingKey.algorithm());
builder.sseCustomerKeyMD5(md5Checksum(encodingKey));
});
key.ifPresentOrElse(
encryption ->
builder.sseCustomerKey(encoded(encryption))
.sseCustomerAlgorithm(encryption.algorithm())
.sseCustomerKeyMD5(md5Checksum(encryption)),
() -> setEncryptionSettings(builder, context.s3SseContext()));
if (exclusiveCreate) {
builder.ifNoneMatch("*");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.filesystem.s3;

import static java.util.Objects.requireNonNull;
import static software.amazon.awssdk.utils.BinaryUtils.fromBase64;
import static software.amazon.awssdk.utils.Md5Utils.md5AsBase64;

public record S3SseCustomerKey(String key, String md5, String algorithm)
{
private static final String SSE_C_ALGORITHM = "AES256";

public S3SseCustomerKey
{
requireNonNull(key, "key is null");
requireNonNull(md5, "md5 is null");
requireNonNull(algorithm, "algorithm is null");
}

public static S3SseCustomerKey onAes256(String key)
{
return new S3SseCustomerKey(key, md5AsBase64(fromBase64(key)), SSE_C_ALGORITHM);
}
}
Loading

0 comments on commit 7aae3f7

Please sign in to comment.