diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3Context.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3Context.java index 3000460997ad9..eeb33295d8329 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3Context.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3Context.java @@ -14,6 +14,7 @@ package io.trino.filesystem.s3; import io.trino.filesystem.s3.S3FileSystemConfig.ObjectCannedAcl; +import io.trino.filesystem.s3.S3FileSystemConfig.S3SseType; import io.trino.spi.security.ConnectorIdentity; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; @@ -85,18 +86,28 @@ public void applyCredentialProviderOverride(AwsRequestOverrideConfiguration.Buil credentialsProviderOverride.ifPresent(builder::credentialsProvider); } - record S3SseContext(S3FileSystemConfig.S3SseType sseType, Optional sseKmsKeyId) + record S3SseContext(S3SseType sseType, Optional sseKmsKeyId, Optional sseCustomerKey) { - public S3SseContext + S3SseContext { requireNonNull(sseType, "sseType is null"); requireNonNull(sseKmsKeyId, "sseKmsKeyId is null"); - checkArgument((sseType != KMS) || (sseKmsKeyId.isPresent()), "sseKmsKeyId is missing for SSE-KMS"); + requireNonNull(sseCustomerKey, "sseCustomerKey is null"); + switch (sseType) { + case KMS -> checkArgument(sseKmsKeyId.isPresent(), "sseKmsKeyId is missing for SSE-KMS"); + case CUSTOMER -> checkArgument(sseCustomerKey.isPresent(), "sseCustomerKey is missing for SSE-C"); + case NONE, S3 -> {} + } + } + + public static S3SseContext of(S3SseType sseType, String sseKmsKeyId, String sseCustomerKey) + { + return new S3SseContext(sseType, Optional.ofNullable(sseKmsKeyId), Optional.ofNullable(sseCustomerKey).map(S3SseCustomerKey::onAes256)); } public static S3SseContext withKmsKeyId(String kmsKeyId) { - return new S3SseContext(KMS, Optional.ofNullable(kmsKeyId)); + return new S3SseContext(KMS, Optional.ofNullable(kmsKeyId), Optional.empty()); } } } diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystem.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystem.java index 236d05765bbc7..c4ac486b61dbb 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystem.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystem.java @@ -54,11 +54,14 @@ import java.util.concurrent.Executor; import java.util.stream.Stream; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.partition; import static com.google.common.collect.Multimaps.toMultimap; +import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE; import static io.trino.filesystem.s3.S3SseCUtils.encoded; import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum; +import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; @@ -341,16 +344,20 @@ public Optional encryptedPreSignedUri(Location location, Duration t location.verifyValidFileLocation(); S3Location s3Location = new S3Location(location); + verify(key.isEmpty() || context.s3SseContext().sseType() == NONE, "Encryption key cannot be used with SSE configuration"); + GetObjectRequest request = GetObjectRequest.builder() .overrideConfiguration(context::applyCredentialProviderOverride) .requestPayer(requestPayer) .key(s3Location.key()) .bucket(s3Location.bucket()) - .applyMutation(builder -> key.ifPresent(encryption -> { - builder.sseCustomerKeyMD5(md5Checksum(encryption)); - builder.sseCustomerAlgorithm(encryption.algorithm()); - builder.sseCustomerKey(encoded(encryption)); - })) + .applyMutation(builder -> + key.ifPresentOrElse( + encryption -> + builder.sseCustomerKeyMD5(md5Checksum(encryption)) + .sseCustomerAlgorithm(encryption.algorithm()) + .sseCustomerKey(encoded(encryption)), + () -> setEncryptionSettings(builder, context.s3SseContext()))) .build(); GetObjectPresignRequest preSignRequest = GetObjectPresignRequest.builder() diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemConfig.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemConfig.java index 9b22a81d51bab..cf340d78645e5 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemConfig.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemConfig.java @@ -23,6 +23,7 @@ import io.airlift.units.Duration; import io.airlift.units.MaxDataSize; import io.airlift.units.MinDataSize; +import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; import software.amazon.awssdk.retries.api.RetryStrategy; @@ -41,7 +42,7 @@ public class S3FileSystemConfig { public enum S3SseType { - NONE, S3, KMS + NONE, S3, KMS, CUSTOMER } public enum ObjectCannedAcl @@ -96,6 +97,7 @@ public static RetryStrategy getRetryStrategy(RetryMode retryMode) private String stsRegion; private S3SseType sseType = S3SseType.NONE; private String sseKmsKeyId; + private String sseCustomerKey; private boolean useWebIdentityTokenCredentialsProvider; private DataSize streamingPartSize = DataSize.of(16, MEGABYTE); private boolean requesterPays; @@ -320,6 +322,29 @@ public S3FileSystemConfig setUseWebIdentityTokenCredentialsProvider(boolean useW return this; } + public String getSseCustomerKey() + { + return sseCustomerKey; + } + + @Config("s3.sse.customer-key") + @ConfigDescription("Customer Key to use for S3 server-side encryption with Customer key (SSE-C)") + @ConfigSecuritySensitive + public S3FileSystemConfig setSseCustomerKey(String sseCustomerKey) + { + this.sseCustomerKey = sseCustomerKey; + return this; + } + + @AssertTrue(message = "s3.sse.customer-key has to be set for server-side encryption with customer-provided key") + public boolean isSseWithCustomerKeyConfigValid() + { + if (sseType == S3SseType.CUSTOMER) { + return sseCustomerKey != null; + } + return true; + } + @NotNull @MinDataSize("5MB") @MaxDataSize("256MB") diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemLoader.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemLoader.java index 93f72eab6c08d..e9e92c4e2e295 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemLoader.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3FileSystemLoader.java @@ -89,9 +89,10 @@ private S3FileSystemLoader(Optional mappingProvider, this.context = new S3Context( toIntExact(config.getStreamingPartSize().toBytes()), config.isRequesterPays(), - new S3SseContext( + S3SseContext.of( config.getSseType(), - Optional.ofNullable(config.getSseKmsKeyId())), + config.getSseKmsKeyId(), + config.getSseCustomerKey()), Optional.empty(), config.getCannedAcl(), config.isSupportsExclusiveCreate()); diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3InputFile.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3InputFile.java index b6bddd2d88f0a..f9a2f77f41e0a 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3InputFile.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3InputFile.java @@ -32,8 +32,11 @@ import java.time.Instant; import java.util.Optional; +import static com.google.common.base.Verify.verify; +import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE; import static io.trino.filesystem.s3.S3SseCUtils.encoded; import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum; +import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings; import static java.util.Objects.requireNonNull; final class S3InputFile @@ -57,6 +60,8 @@ public S3InputFile(S3Client client, S3Context context, S3Location location, Long this.lastModified = lastModified; this.key = requireNonNull(key, "key is null"); location.location().verifyValidFileLocation(); + + verify(key.isEmpty() || context.s3SseContext().sseType() == NONE, "Encryption key cannot be used with SSE configuration"); } @Override @@ -111,11 +116,13 @@ private GetObjectRequest newGetObjectRequest() .requestPayer(requestPayer) .bucket(location.bucket()) .key(location.key()) - .applyMutation(builder -> key.ifPresent(encryption -> { - builder.sseCustomerKey(encoded(encryption)); - builder.sseCustomerAlgorithm(encryption.algorithm()); - builder.sseCustomerKeyMD5(md5Checksum(encryption)); - })) + .applyMutation(builder -> + key.ifPresentOrElse( + encryption -> + builder.sseCustomerKey(encoded(encryption)) + .sseCustomerAlgorithm(encryption.algorithm()) + .sseCustomerKeyMD5(md5Checksum(encryption)), + () -> setEncryptionSettings(builder, context.s3SseContext()))) .build(); } @@ -127,11 +134,13 @@ private boolean headObject() .requestPayer(requestPayer) .bucket(location.bucket()) .key(location.key()) - .applyMutation(builder -> key.ifPresent(encryption -> { - builder.sseCustomerKey(encoded(encryption)); - builder.sseCustomerAlgorithm(encryption.algorithm()); - builder.sseCustomerKeyMD5(md5Checksum(encryption)); - })) + .applyMutation(builder -> + key.ifPresentOrElse( + encryption -> + builder.sseCustomerKey(encoded(encryption)) + .sseCustomerAlgorithm(encryption.algorithm()) + .sseCustomerKeyMD5(md5Checksum(encryption)), + () -> setEncryptionSettings(builder, context.s3SseContext()))) .build(); try { diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java index 3377f716d9ab4..f573fe9ae46a4 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3OutputStream.java @@ -48,7 +48,7 @@ import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE; import static io.trino.filesystem.s3.S3SseCUtils.encoded; import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum; -import static io.trino.filesystem.s3.S3SseRequestConfigurator.addEncryptionSettings; +import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings; import static java.lang.Math.clamp; import static java.lang.Math.max; import static java.lang.Math.min; @@ -224,7 +224,7 @@ private void flushBuffer(boolean finished) builder.sseCustomerAlgorithm(encryption.algorithm()); builder.sseCustomerKeyMD5(md5Checksum(encryption)); }); - addEncryptionSettings(builder, context.s3SseContext()); + setEncryptionSettings(builder, context.s3SseContext()); }) .build(); @@ -304,14 +304,13 @@ private CompletedPart uploadPage(byte[] data, int length) .requestPayer(requestPayer) .bucket(location.bucket()) .key(location.key()) - .applyMutation(builder -> { - key.ifPresent(encryption -> { - builder.sseCustomerKey(encoded(encryption)); - builder.sseCustomerAlgorithm(encryption.algorithm()); - builder.sseCustomerKeyMD5(md5Checksum(encryption)); - }); - addEncryptionSettings(builder, context.s3SseContext()); - }) + .applyMutation(builder -> + key.ifPresentOrElse( + encryption -> + builder.sseCustomerKey(encoded(encryption)) + .sseCustomerAlgorithm(encryption.algorithm()) + .sseCustomerKeyMD5(md5Checksum(encryption)), + () -> setEncryptionSettings(builder, context.s3SseContext()))) .build(); uploadId = Optional.of(client.createMultipartUpload(request).uploadId()); @@ -326,11 +325,13 @@ private CompletedPart uploadPage(byte[] data, int length) .contentLength((long) length) .uploadId(uploadId.get()) .partNumber(currentPartNumber) - .applyMutation(builder -> key.ifPresent(encryption -> { - builder.sseCustomerKey(encoded(encryption)); - builder.sseCustomerAlgorithm(encryption.algorithm()); - builder.sseCustomerKeyMD5(md5Checksum(encryption)); - })) + .applyMutation(builder -> + key.ifPresentOrElse( + encryption -> + builder.sseCustomerKey(encoded(encryption)) + .sseCustomerAlgorithm(encryption.algorithm()) + .sseCustomerKeyMD5(md5Checksum(encryption)), + () -> setEncryptionSettings(builder, context.s3SseContext()))) .build(); ByteBuffer bytes = ByteBuffer.wrap(data, 0, length); @@ -356,11 +357,12 @@ private void finishUpload(String uploadId) .uploadId(uploadId) .multipartUpload(x -> x.parts(parts)) .applyMutation(builder -> { - key.ifPresent(encodingKey -> { - builder.sseCustomerKey(encoded(encodingKey)); - builder.sseCustomerAlgorithm(encodingKey.algorithm()); - builder.sseCustomerKeyMD5(md5Checksum(encodingKey)); - }); + key.ifPresentOrElse( + encryption -> + builder.sseCustomerKey(encoded(encryption)) + .sseCustomerAlgorithm(encryption.algorithm()) + .sseCustomerKeyMD5(md5Checksum(encryption)), + () -> setEncryptionSettings(builder, context.s3SseContext())); if (exclusiveCreate) { builder.ifNoneMatch("*"); } diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3SseCustomerKey.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3SseCustomerKey.java new file mode 100644 index 0000000000000..9c9a4be3bed26 --- /dev/null +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3SseCustomerKey.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.filesystem.s3; + +import static java.util.Objects.requireNonNull; +import static software.amazon.awssdk.utils.BinaryUtils.fromBase64; +import static software.amazon.awssdk.utils.Md5Utils.md5AsBase64; + +public record S3SseCustomerKey(String key, String md5, String algorithm) +{ + private static final String SSE_C_ALGORITHM = "AES256"; + + public S3SseCustomerKey + { + requireNonNull(key, "key is null"); + requireNonNull(md5, "md5 is null"); + requireNonNull(algorithm, "algorithm is null"); + } + + public static S3SseCustomerKey onAes256(String key) + { + return new S3SseCustomerKey(key, md5AsBase64(fromBase64(key)), SSE_C_ALGORITHM); + } +} diff --git a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3SseRequestConfigurator.java b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3SseRequestConfigurator.java index 9bb50c94bfdf9..8559117a187c8 100644 --- a/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3SseRequestConfigurator.java +++ b/lib/trino-filesystem-s3/src/main/java/io/trino/filesystem/s3/S3SseRequestConfigurator.java @@ -14,9 +14,14 @@ package io.trino.filesystem.s3; import io.trino.filesystem.s3.S3Context.S3SseContext; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.CUSTOMER; import static software.amazon.awssdk.services.s3.model.ServerSideEncryption.AES256; import static software.amazon.awssdk.services.s3.model.ServerSideEncryption.AWS_KMS; @@ -24,21 +29,73 @@ public final class S3SseRequestConfigurator { private S3SseRequestConfigurator() {} - public static void addEncryptionSettings(PutObjectRequest.Builder builder, S3SseContext context) + public static void setEncryptionSettings(PutObjectRequest.Builder builder, S3SseContext context) { switch (context.sseType()) { case NONE -> { /* ignored */ } case S3 -> builder.serverSideEncryption(AES256); case KMS -> context.sseKmsKeyId().ifPresent(builder.serverSideEncryption(AWS_KMS)::ssekmsKeyId); + case CUSTOMER -> { + context.sseCustomerKey().ifPresent(s3SseCustomerKey -> + builder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()) + .sseCustomerKey(s3SseCustomerKey.key()) + .sseCustomerKeyMD5(s3SseCustomerKey.md5())); + } } } - public static void addEncryptionSettings(CreateMultipartUploadRequest.Builder builder, S3SseContext context) + public static void setEncryptionSettings(CreateMultipartUploadRequest.Builder builder, S3SseContext context) { switch (context.sseType()) { case NONE -> { /* ignored */ } case S3 -> builder.serverSideEncryption(AES256); case KMS -> context.sseKmsKeyId().ifPresent(builder.serverSideEncryption(AWS_KMS)::ssekmsKeyId); + case CUSTOMER -> { + context.sseCustomerKey().ifPresent(s3SseCustomerKey -> + builder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()) + .sseCustomerKey(s3SseCustomerKey.key()) + .sseCustomerKeyMD5(s3SseCustomerKey.md5())); + } + } + } + + public static void setEncryptionSettings(CompleteMultipartUploadRequest.Builder builder, S3SseContext context) + { + if (context.sseType() == CUSTOMER) { + context.sseCustomerKey().ifPresent(s3SseCustomerKey -> + builder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()) + .sseCustomerKey(s3SseCustomerKey.key()) + .sseCustomerKeyMD5(s3SseCustomerKey.md5())); + } + } + + public static void setEncryptionSettings(GetObjectRequest.Builder builder, S3SseContext context) + { + if (context.sseType().equals(CUSTOMER)) { + context.sseCustomerKey().ifPresent(s3SseCustomerKey -> + builder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()) + .sseCustomerKey(s3SseCustomerKey.key()) + .sseCustomerKeyMD5(s3SseCustomerKey.md5())); + } + } + + public static void setEncryptionSettings(HeadObjectRequest.Builder builder, S3SseContext context) + { + if (context.sseType().equals(CUSTOMER)) { + context.sseCustomerKey().ifPresent(s3SseCustomerKey -> + builder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()) + .sseCustomerKey(s3SseCustomerKey.key()) + .sseCustomerKeyMD5(s3SseCustomerKey.md5())); + } + } + + public static void setEncryptionSettings(UploadPartRequest.Builder builder, S3SseContext context) + { + if (context.sseType() == CUSTOMER) { + context.sseCustomerKey().ifPresent(s3SseCustomerKey -> + builder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()) + .sseCustomerKey(s3SseCustomerKey.key()) + .sseCustomerKeyMD5(s3SseCustomerKey.md5())); } } } diff --git a/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestS3FileSystemAwsS3WithSseCustomerKey.java b/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestS3FileSystemAwsS3WithSseCustomerKey.java new file mode 100644 index 0000000000000..7dcdaba12f8ee --- /dev/null +++ b/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestS3FileSystemAwsS3WithSseCustomerKey.java @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.filesystem.s3; + +import io.airlift.units.DataSize; +import io.opentelemetry.api.OpenTelemetry; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.DelegatingS3Client; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Request; +import software.amazon.awssdk.utils.BinaryUtils; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; + +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.function.Function; + +import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.CUSTOMER; +import static java.util.Objects.requireNonNull; + +public class TestS3FileSystemAwsS3WithSseCustomerKey + extends AbstractTestS3FileSystem +{ + private static final String CUSTOMER_KEY = generateCustomerKey(); + + private String accessKey; + private String secretKey; + private String region; + private String bucket; + private S3SseCustomerKey s3SseCustomerKey; + + @Override + protected void initEnvironment() + { + accessKey = environmentVariable("AWS_ACCESS_KEY_ID"); + secretKey = environmentVariable("AWS_SECRET_ACCESS_KEY"); + region = environmentVariable("AWS_REGION"); + bucket = environmentVariable("EMPTY_S3_BUCKET"); + s3SseCustomerKey = S3SseCustomerKey.onAes256(CUSTOMER_KEY); + } + + @Override + protected String bucket() + { + return bucket; + } + + @Override + protected S3Client createS3Client() + { + S3Client s3Client = S3Client.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKey, secretKey))) + .region(Region.of(region)) + .build(); + + return new DelegatingS3Client(s3Client) + { + @Override + protected ReturnT invokeOperation(T request, Function operation) + { + if (request instanceof PutObjectRequest putObjectRequest) { + PutObjectRequest.Builder putObjectRequestBuilder = putObjectRequest.toBuilder(); + putObjectRequestBuilder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()); + putObjectRequestBuilder.sseCustomerKey(s3SseCustomerKey.key()); + putObjectRequestBuilder.sseCustomerKeyMD5(s3SseCustomerKey.md5()); + return operation.apply((T) putObjectRequestBuilder.build()); + } + else if (request instanceof GetObjectRequest getObjectRequest) { + GetObjectRequest.Builder getObjectRequestBuilder = getObjectRequest.toBuilder(); + getObjectRequestBuilder.sseCustomerAlgorithm(s3SseCustomerKey.algorithm()); + getObjectRequestBuilder.sseCustomerKey(s3SseCustomerKey.key()); + getObjectRequestBuilder.sseCustomerKeyMD5(s3SseCustomerKey.md5()); + return operation.apply((T) getObjectRequestBuilder.build()); + } + return operation.apply(request); + } + }; + } + + @Override + protected S3FileSystemFactory createS3FileSystemFactory() + { + return new S3FileSystemFactory( + OpenTelemetry.noop(), + new S3FileSystemConfig() + .setAwsAccessKey(accessKey) + .setAwsSecretKey(secretKey) + .setRegion(region) + .setSseType(CUSTOMER) + .setSseCustomerKey(s3SseCustomerKey.key()) + .setStreamingPartSize(DataSize.valueOf("5.5MB")), + new S3FileSystemStats()); + } + + private static String environmentVariable(String name) + { + return requireNonNull(System.getenv(name), "Environment variable not set: " + name); + } + + private static String generateCustomerKey() + { + try { + KeyGenerator keyGenerator = KeyGenerator.getInstance("AES"); + keyGenerator.init(256, new SecureRandom()); + SecretKey secretKey = keyGenerator.generateKey(); + return BinaryUtils.toBase64(secretKey.getEncoded()); + } + catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } +} diff --git a/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestS3FileSystemConfig.java b/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestS3FileSystemConfig.java index 086fdb13558f7..5fa6c3b86d621 100644 --- a/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestS3FileSystemConfig.java +++ b/lib/trino-filesystem-s3/src/test/java/io/trino/filesystem/s3/TestS3FileSystemConfig.java @@ -19,6 +19,7 @@ import io.airlift.units.Duration; import io.trino.filesystem.s3.S3FileSystemConfig.ObjectCannedAcl; import io.trino.filesystem.s3.S3FileSystemConfig.S3SseType; +import jakarta.validation.constraints.AssertTrue; import org.junit.jupiter.api.Test; import java.util.Map; @@ -26,6 +27,7 @@ import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static io.airlift.testing.ValidationAssertions.assertFailsValidation; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.filesystem.s3.S3FileSystemConfig.RetryMode.LEGACY; import static io.trino.filesystem.s3.S3FileSystemConfig.RetryMode.STANDARD; @@ -53,6 +55,7 @@ public void testDefaults() .setMaxErrorRetries(10) .setSseKmsKeyId(null) .setUseWebIdentityTokenCredentialsProvider(false) + .setSseCustomerKey(null) .setStreamingPartSize(DataSize.of(16, MEGABYTE)) .setRequesterPays(false) .setMaxConnections(500) @@ -89,6 +92,7 @@ public void testExplicitPropertyMappings() .put("s3.max-error-retries", "12") .put("s3.sse.type", "KMS") .put("s3.sse.kms-key-id", "mykey") + .put("s3.sse.customer-key", "customerKey") .put("s3.use-web-identity-token-credentials-provider", "true") .put("s3.streaming.part-size", "42MB") .put("s3.requester-pays", "true") @@ -125,6 +129,7 @@ public void testExplicitPropertyMappings() .setSseType(S3SseType.KMS) .setSseKmsKeyId("mykey") .setUseWebIdentityTokenCredentialsProvider(true) + .setSseCustomerKey("customerKey") .setRequesterPays(true) .setMaxConnections(42) .setConnectionTtl(new Duration(1, MINUTES)) @@ -142,4 +147,14 @@ public void testExplicitPropertyMappings() assertFullMapping(properties, expected); } + + @Test + public void testSSEWithCustomerKeyValidation() + { + assertFailsValidation(new S3FileSystemConfig() + .setSseType(S3SseType.CUSTOMER), + "sseWithCustomerKeyConfigValid", + "s3.sse.customer-key has to be set for server-side encryption with customer-provided key", + AssertTrue.class); + } }