diff --git a/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/s3/express/S3ExpressIntegration.kt b/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/s3/express/S3ExpressIntegration.kt index 430c4f8ebc1..03b552141be 100644 --- a/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/s3/express/S3ExpressIntegration.kt +++ b/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/s3/express/S3ExpressIntegration.kt @@ -16,8 +16,6 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerato import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigProperty import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigPropertyType -import software.amazon.smithy.kotlin.codegen.utils.dq -import software.amazon.smithy.kotlin.codegen.utils.getOrNull import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.* import software.amazon.smithy.model.traits.* @@ -99,7 +97,6 @@ class S3ExpressIntegration : KotlinIntegration { resolved + listOf( addClientToExecutionContext, addBucketToExecutionContext, - useCrc32Checksum, uploadPartDisableChecksum, ) @@ -132,36 +129,6 @@ class S3ExpressIntegration : KotlinIntegration { } } - /** - * For any operations that require a checksum, set CRC32 if the user has not already configured a checksum. - */ - private val useCrc32Checksum = object : ProtocolMiddleware { - override val name: String = "UseCrc32Checksum" - - override val order: Byte = -1 // Render before flexible checksums - - override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean = !op.isS3UploadPart && - (op.hasTrait() || (op.hasTrait() && op.expectTrait().isRequestChecksumRequired)) - - override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { - val interceptorSymbol = buildSymbol { - namespace = "aws.sdk.kotlin.services.s3.express" - name = "S3ExpressCrc32ChecksumInterceptor" - } - - val httpChecksumTrait = op.getTrait() - - val checksumAlgorithmMember = ctx.model.expectShape(op.input.get()) - .members() - .firstOrNull { it.memberName == httpChecksumTrait?.requestAlgorithmMember?.getOrNull() } - - // S3 models a header name x-amz-sdk-checksum-algorithm representing the name of the checksum algorithm used - val checksumHeaderName = checksumAlgorithmMember?.getTrait()?.value - - writer.write("op.interceptors.add(#T(${checksumHeaderName?.dq() ?: ""}))", interceptorSymbol) - } - } - /** * Disable all checksums for s3:UploadPart */ @@ -169,7 +136,7 @@ class S3ExpressIntegration : KotlinIntegration { override val name: String = "UploadPartDisableChecksum" override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean = - op.isS3UploadPart + op.isS3UploadPart && op.hasTrait() override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { val interceptorSymbol = buildSymbol { diff --git a/services/s3/common/src/aws/sdk/kotlin/services/s3/express/S3ExpressDisableChecksumInterceptor.kt b/services/s3/common/src/aws/sdk/kotlin/services/s3/express/S3ExpressDisableChecksumInterceptor.kt index 3b10ad3fa69..e78cde99060 100644 --- a/services/s3/common/src/aws/sdk/kotlin/services/s3/express/S3ExpressDisableChecksumInterceptor.kt +++ b/services/s3/common/src/aws/sdk/kotlin/services/s3/express/S3ExpressDisableChecksumInterceptor.kt @@ -6,14 +6,18 @@ package aws.sdk.kotlin.services.s3.express import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext import aws.smithy.kotlin.runtime.collections.AttributeKey +import aws.smithy.kotlin.runtime.http.DeferredHeadersBuilder +import aws.smithy.kotlin.runtime.http.HeadersBuilder import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor -import aws.smithy.kotlin.runtime.http.operation.HttpOperationContext import aws.smithy.kotlin.runtime.http.request.HttpRequest +import aws.smithy.kotlin.runtime.http.request.toBuilder import aws.smithy.kotlin.runtime.telemetry.logging.logger import kotlin.coroutines.coroutineContext +private const val CHECKSUM_HEADER_PREFIX = "x-amz-checksum-" + /** - * Disable checksums entirely for s3:UploadPart requests. + * Disables checksums for s3:UploadPart requests that use S3 express. */ internal class S3ExpressDisableChecksumInterceptor : HttpInterceptor { override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext): HttpRequest { @@ -22,14 +26,45 @@ internal class S3ExpressDisableChecksumInterceptor : HttpInterceptor { } val logger = coroutineContext.logger() + logger.warn { "Checksums must not be sent with S3 express upload part operation, removing checksum(s)" } + + val request = context.protocolRequest.toBuilder() + + request.headers.removeChecksumHeaders() + request.trailingHeaders.removeChecksumTrailingHeaders() + request.headers.removeChecksumTrailingHeadersFromXAmzTrailer() + + return request.build() + } +} - val configuredChecksumAlgorithm = context.executionContext.getOrNull(HttpOperationContext.ChecksumAlgorithm) +/** + * Removes any checksums sent in the request's headers + */ +internal fun HeadersBuilder.removeChecksumHeaders(): Unit = + names().forEach { name -> + if (name.startsWith(CHECKSUM_HEADER_PREFIX)) { + remove(name) + } + } - configuredChecksumAlgorithm?.let { - logger.warn { "Disabling configured checksum $it for S3 Express UploadPart" } - context.executionContext.remove(HttpOperationContext.ChecksumAlgorithm) +/** + * Removes any checksums sent in the request's trailing headers + */ +internal fun DeferredHeadersBuilder.removeChecksumTrailingHeaders(): Unit = + names().forEach { name -> + if (name.startsWith(CHECKSUM_HEADER_PREFIX)) { + remove(name) } + } - return context.protocolRequest +/** + * Removes any checksums sent in the request's trailing headers from `x-amz-trailer` + */ +internal fun HeadersBuilder.removeChecksumTrailingHeadersFromXAmzTrailer() { + this.getAll("x-amz-trailer")?.forEach { trailingHeader -> + if (trailingHeader.startsWith(CHECKSUM_HEADER_PREFIX)) { + this.remove("x-amz-trailer", trailingHeader) + } } } diff --git a/services/s3/common/test/aws/sdk/kotlin/services/s3/express/ChecksumRemovalTest.kt b/services/s3/common/test/aws/sdk/kotlin/services/s3/express/ChecksumRemovalTest.kt new file mode 100644 index 00000000000..cbf5a9ad4d4 --- /dev/null +++ b/services/s3/common/test/aws/sdk/kotlin/services/s3/express/ChecksumRemovalTest.kt @@ -0,0 +1,83 @@ +package aws.sdk.kotlin.services.s3.express + +import aws.smithy.kotlin.runtime.http.DeferredHeadersBuilder +import aws.smithy.kotlin.runtime.http.HeadersBuilder +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class ChecksumRemovalTest { + @Test + fun removeChecksumHeaders() { + val headers = HeadersBuilder() + + headers.append("x-amz-checksum-crc32", "foo") + headers.append("x-amz-checksum-sha256", "bar") + + assertTrue( + headers.contains("x-amz-checksum-crc32"), + ) + assertTrue( + headers.contains("x-amz-checksum-sha256"), + ) + + headers.removeChecksumHeaders() + + assertFalse( + headers.contains("x-amz-checksum-crc32"), + ) + assertFalse( + headers.contains("x-amz-checksum-sha256"), + ) + } + + @Test + fun removeChecksumTrailingHeaders() { + val trailingHeaders = DeferredHeadersBuilder() + + trailingHeaders.add("x-amz-checksum-crc32", "foo") + trailingHeaders.add("x-amz-checksum-sha256", "bar") + + assertTrue( + trailingHeaders.contains("x-amz-checksum-crc32"), + ) + assertTrue( + trailingHeaders.contains("x-amz-checksum-sha256"), + ) + + trailingHeaders.removeChecksumTrailingHeaders() + + assertFalse( + trailingHeaders.contains("x-amz-checksum-crc32"), + ) + assertFalse( + trailingHeaders.contains("x-amz-checksum-sha256"), + ) + } + + @Test + fun removeChecksumTrailingHeadersFromXAmzTrailer() { + val headers = HeadersBuilder() + + headers.append("x-amz-trailer", "x-amz-checksum-crc32") + headers.append("x-amz-trailer", "x-amz-trailing-header") + + val xAmzTrailer = headers.getAll("x-amz-trailer") + + assertTrue( + xAmzTrailer?.contains("x-amz-checksum-crc32") ?: false, + ) + assertTrue( + xAmzTrailer?.contains("x-amz-trailing-header") ?: false, + ) + + headers.removeChecksumTrailingHeadersFromXAmzTrailer() + + assertFalse( + xAmzTrailer?.contains("x-amz-checksum-crc32") ?: false, + ) + assertTrue( + xAmzTrailer?.contains("x-amz-trailing-header") ?: false, + ) + } +}