Skip to content

Commit

Permalink
Make it compile
Browse files Browse the repository at this point in the history
  • Loading branch information
0marperez committed Dec 19, 2024
1 parent 1157f26 commit d9f0659
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private val httpChecksumRequiredMiddleware = object : ProtocolMiddleware {
override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
writer.write(
"op.interceptors.add(#T())",
RuntimeTypes.HttpClient.Interceptors.HttpChecksumRequiredInterceptor
RuntimeTypes.HttpClient.Interceptors.HttpChecksumRequiredInterceptor,
)
}
}
14 changes: 4 additions & 10 deletions runtime/protocol/http-client/api/http-client.api
Original file line number Diff line number Diff line change
Expand Up @@ -332,20 +332,16 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpoin
}

public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor : aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor {
public fun <init> ()V
public fun <init> (Lkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (ZLaws/smithy/kotlin/runtime/client/config/HttpChecksumConfigOption;Ljava/lang/String;)V
public fun applyChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Ljava/lang/String;)Laws/smithy/kotlin/runtime/http/request/HttpRequest;
public fun calculateChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun modifyBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun readAfterSerialization (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V
}

public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor : aws/smithy/kotlin/runtime/client/Interceptor {
public class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor : aws/smithy/kotlin/runtime/client/Interceptor {
public static final field Companion Laws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor$Companion;
public fun <init> (Lkotlin/jvm/functions/Function1;)V
public fun <init> (ZLaws/smithy/kotlin/runtime/client/config/HttpChecksumConfigOption;)V
public fun ignoreChecksum (Ljava/lang/String;)Z
public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand All @@ -371,10 +367,8 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksums
public final fun getChecksumHeaderValidated ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
}

public final class aws/smithy/kotlin/runtime/http/interceptors/Md5ChecksumInterceptor : aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor {
public final class aws/smithy/kotlin/runtime/http/interceptors/HttpChecksumRequiredInterceptor : aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor {
public fun <init> ()V
public fun <init> (Lkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun applyChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Ljava/lang/String;)Laws/smithy/kotlin/runtime/http/request/HttpRequest;
public fun calculateChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun modifyBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand Down Expand Up @@ -518,9 +512,9 @@ public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpDes

public final class aws/smithy/kotlin/runtime/http/operation/HttpOperationContext {
public static final field INSTANCE Laws/smithy/kotlin/runtime/http/operation/HttpOperationContext;
public final fun getChecksumAlgorithm ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getClockSkew ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getClockSkewApproximateSigningTime ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getDefaultChecksumAlgorithm ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getHostPrefix ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getHttpCallList ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getOperationAttributes ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ public class FlexibleChecksumsRequestInterceptor(

context.protocolRequest.userProvidedChecksumHeader(logger)?.let {
logger.debug { "Checksum was supplied via header: skipping checksum calculation" }

val request = context.protocolRequest.toBuilder()
request.headers.removeAllChecksumHeadersExcept(it)
return context.protocolRequest
}

checksumAlgorithm(
requestChecksumRequired,
requestChecksumCalculation,
requestChecksumAlgorithm,
context
context,
)?.let { checksumAlgorithm ->
if (context.protocolRequest.body.isEligibleForAwsChunkedStreaming) { // Handle checksum calculation here
logger.debug { "Calculating checksum during transmission using: ${checksumAlgorithm::class.simpleName}" }
Expand Down Expand Up @@ -97,7 +100,7 @@ public class FlexibleChecksumsRequestInterceptor(
requestChecksumRequired: Boolean,
requestChecksumCalculation: HttpChecksumConfigOption?,
requestChecksumAlgorithm: String?,
context: ProtocolRequestInterceptorContext<Any, HttpRequest>
context: ProtocolRequestInterceptorContext<Any, HttpRequest>,
): HashFunction? =
requestChecksumAlgorithm
?.toHashFunctionOrThrow()
Expand All @@ -116,7 +119,7 @@ public class FlexibleChecksumsRequestInterceptor(
requestChecksumRequired,
requestChecksumCalculation,
requestChecksumAlgorithm,
context
context,
)!!

return when {
Expand All @@ -135,18 +138,19 @@ public class FlexibleChecksumsRequestInterceptor(
// Handles applying checksum for non-aws-chunked requests
override fun applyChecksum(
context: ProtocolRequestInterceptorContext<Any, HttpRequest>,
checksum: String
checksum: String,
): HttpRequest {
val request = context.protocolRequest.toBuilder()
val checksumAlgorithm = checksumAlgorithm(
requestChecksumRequired,
requestChecksumCalculation,
requestChecksumAlgorithm,
context
context,
)!!
val checksumHeader = checksumAlgorithmHeader(checksumAlgorithm)

request.headers[checksumHeader] = checksum
request.headers.removeAllChecksumHeadersExcept(checksumHeader)
context.executionContext.emitBusinessMetric(checksumAlgorithm.toBusinessMetric())

return request.build()
Expand Down Expand Up @@ -197,4 +201,12 @@ public class FlexibleChecksumsRequestInterceptor(
is Sha256 -> SmithyBusinessMetric.FLEXIBLE_CHECKSUMS_REQ_SHA256
else -> throw IllegalStateException("Checksum was calculated using an unsupported hash function: ${this::class.simpleName}")
}

/**
* Removes all checksum headers except specified header
*/
private fun HeadersBuilder.removeAllChecksumHeadersExcept(checksumHeader: String) =
names()
.filter { it.startsWith("x-amz-checksum-", ignoreCase = true) && !it.equals(checksumHeader, ignoreCase = true) }
.forEach { remove(it) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ public open class FlexibleChecksumsResponseInterceptor(

val checksumHeader = CHECKSUM_HEADER_VALIDATION_PRIORITY_LIST
.firstOrNull { context.protocolResponse.headers.contains(it) } ?: run {
logger.warn { "Checksum validation was requested but the response headers didn't contain a valid checksum." }
return context.protocolResponse
}
logger.warn { "Checksum validation was requested but the response headers didn't contain a valid checksum." }
return context.protocolResponse
}

val serviceChecksumValue = context.protocolResponse.headers[checksumHeader]!!
if (ignoreChecksum(serviceChecksumValue)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,3 @@ public class HttpChecksumRequiredInterceptor : AbstractChecksumInterceptor() {
return request.build()
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class FlexibleChecksumsRequestInterceptorTest {
val op = newTestOperation<Unit, Unit>(req, Unit)

op.interceptors.add(
FlexibleChecksumsRequestInterceptor<Unit>(
FlexibleChecksumsRequestInterceptor(
requestChecksumAlgorithm = checksumAlgorithmName,
requestChecksumRequired = true,
requestChecksumCalculation = HttpChecksumConfigOption.WHEN_SUPPORTED,
Expand All @@ -68,8 +68,9 @@ class FlexibleChecksumsRequestInterceptorTest {

val op = newTestOperation<Unit, Unit>(req, Unit)

op.context.attributes[HttpOperationContext.DefaultChecksumAlgorithm] = "CRC32"
op.interceptors.add(
FlexibleChecksumsRequestInterceptor<Unit>(
FlexibleChecksumsRequestInterceptor(
requestChecksumAlgorithm = checksumAlgorithmName,
requestChecksumRequired = true,
requestChecksumCalculation = HttpChecksumConfigOption.WHEN_SUPPORTED,
Expand All @@ -94,12 +95,13 @@ class FlexibleChecksumsRequestInterceptorTest {

assertFailsWith<ClientException> {
op.interceptors.add(
FlexibleChecksumsRequestInterceptor<Unit>(
FlexibleChecksumsRequestInterceptor(
requestChecksumAlgorithm = unsupportedChecksumAlgorithmName,
requestChecksumRequired = true,
requestChecksumCalculation = HttpChecksumConfigOption.WHEN_SUPPORTED,
),
)
op.roundTrip(client, Unit)
}
}

Expand All @@ -120,7 +122,7 @@ class FlexibleChecksumsRequestInterceptorTest {
val op = newTestOperation<Unit, Unit>(req, Unit)

op.interceptors.add(
FlexibleChecksumsRequestInterceptor<Unit>(
FlexibleChecksumsRequestInterceptor(
requestChecksumAlgorithm = checksumAlgorithmName,
requestChecksumRequired = true,
requestChecksumCalculation = HttpChecksumConfigOption.WHEN_SUPPORTED,
Expand All @@ -141,7 +143,7 @@ class FlexibleChecksumsRequestInterceptorTest {
val source = byteArray.source()
val completableDeferred = CompletableDeferred<String>()
val hashingSource = HashingSource(hashFunctionName.toHashFunction()!!, source)
val completingSource = FlexibleChecksumsRequestInterceptor.CompletingSource(completableDeferred, hashingSource)
val completingSource = CompletingSource(completableDeferred, hashingSource)

completingSource.read(SdkBuffer(), 1L)
assertFalse(completableDeferred.isCompleted) // deferred value should not be completed because the source is not exhausted
Expand All @@ -162,7 +164,7 @@ class FlexibleChecksumsRequestInterceptorTest {
val channel = SdkByteReadChannel(byteArray)
val completableDeferred = CompletableDeferred<String>()
val hashingChannel = HashingByteReadChannel(hashFunctionName.toHashFunction()!!, channel)
val completingChannel = FlexibleChecksumsRequestInterceptor.CompletingByteReadChannel(completableDeferred, hashingChannel)
val completingChannel = CompletingByteReadChannel(completableDeferred, hashingChannel)

completingChannel.read(SdkBuffer(), 1L)
assertFalse(completableDeferred.isCompleted)
Expand All @@ -188,7 +190,7 @@ class FlexibleChecksumsRequestInterceptorTest {
val op = newTestOperation<Unit, Unit>(req, Unit)

op.interceptors.add(
FlexibleChecksumsRequestInterceptor<Unit>(
FlexibleChecksumsRequestInterceptor(
requestChecksumAlgorithm = checksumAlgorithmName,
requestChecksumRequired = true,
requestChecksumCalculation = HttpChecksumConfigOption.WHEN_SUPPORTED,
Expand Down Expand Up @@ -232,8 +234,9 @@ class FlexibleChecksumsRequestInterceptorTest {

val op = newTestOperation<Unit, Unit>(req, Unit)

op.context.attributes[HttpOperationContext.DefaultChecksumAlgorithm] = "CRC32"
op.interceptors.add(
FlexibleChecksumsRequestInterceptor<Unit>(
FlexibleChecksumsRequestInterceptor(
requestChecksumAlgorithm = null, // See if default checksum is applied
requestChecksumRequired = testCase.requestChecksumRequired,
requestChecksumCalculation = testCase.requestChecksumCalculation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class FlexibleChecksumsResponseInterceptorTest {
val op = newTestOperation<TestInput>(req)

op.interceptors.add(
FlexibleChecksumsResponseInterceptor<TestInput>(
FlexibleChecksumsResponseInterceptor(
responseValidationRequired = true,
responseChecksumValidation = HttpChecksumConfigOption.WHEN_SUPPORTED,
),
Expand All @@ -101,7 +101,7 @@ class FlexibleChecksumsResponseInterceptorTest {
val op = newTestOperation<TestInput>(req)

op.interceptors.add(
FlexibleChecksumsResponseInterceptor<TestInput>(
FlexibleChecksumsResponseInterceptor(
responseValidationRequired = true,
responseChecksumValidation = HttpChecksumConfigOption.WHEN_SUPPORTED,
),
Expand Down Expand Up @@ -129,7 +129,7 @@ class FlexibleChecksumsResponseInterceptorTest {
val op = newTestOperation<TestInput>(req)

op.interceptors.add(
FlexibleChecksumsResponseInterceptor<TestInput>(
FlexibleChecksumsResponseInterceptor(
responseValidationRequired = true,
responseChecksumValidation = HttpChecksumConfigOption.WHEN_SUPPORTED,
),
Expand All @@ -154,7 +154,7 @@ class FlexibleChecksumsResponseInterceptorTest {
val op = newTestOperation<TestInput>(req)

op.interceptors.add(
FlexibleChecksumsResponseInterceptor<TestInput>(
FlexibleChecksumsResponseInterceptor(
responseValidationRequired = true,
responseChecksumValidation = HttpChecksumConfigOption.WHEN_SUPPORTED,
),
Expand All @@ -175,7 +175,7 @@ class FlexibleChecksumsResponseInterceptorTest {
val op = newTestOperation<TestInput>(req)

op.interceptors.add(
FlexibleChecksumsResponseInterceptor<TestInput>(
FlexibleChecksumsResponseInterceptor(
responseValidationRequired = false,
responseChecksumValidation = HttpChecksumConfigOption.WHEN_REQUIRED,
),
Expand Down Expand Up @@ -219,7 +219,7 @@ class FlexibleChecksumsResponseInterceptorTest {
val op = newTestOperation<TestInput>(req)

op.interceptors.add(
FlexibleChecksumsResponseInterceptor<TestInput>(
FlexibleChecksumsResponseInterceptor(
responseValidationRequired = testCase.responseValidationRequired,
responseChecksumValidation = testCase.responseChecksumValidation,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package aws.smithy.kotlin.runtime.http.interceptors

import aws.smithy.kotlin.runtime.collections.get
import aws.smithy.kotlin.runtime.hashing.Crc32
import aws.smithy.kotlin.runtime.http.HttpBody
import aws.smithy.kotlin.runtime.http.SdkHttpClient
import aws.smithy.kotlin.runtime.http.operation.HttpOperationContext
Expand All @@ -14,6 +15,7 @@ import aws.smithy.kotlin.runtime.http.operation.roundTrip
import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder
import aws.smithy.kotlin.runtime.httptest.TestEngine
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.text.encoding.encodeBase64String
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
Expand All @@ -29,10 +31,9 @@ class HttpChecksumRequiredInterceptorTest {
}
val op = newTestOperation<Unit, Unit>(req, Unit)

op.context.attributes[HttpOperationContext.DefaultChecksumAlgorithm] = "MD5"
op.interceptors.add(
HttpChecksumRequiredInterceptor<Unit> {
true
},
HttpChecksumRequiredInterceptor(),
)

val expected = "RG22oBSZFmabBbkzVGRi4w=="
Expand All @@ -41,6 +42,29 @@ class HttpChecksumRequiredInterceptorTest {
assertEquals(expected, call.request.headers["Content-MD5"])
}

@Test
fun itSetsContentCrc32Header() = runTest {
val testBody = "<Foo>bar</Foo>".encodeToByteArray()

val req = HttpRequestBuilder().apply {
body = HttpBody.fromBytes(testBody)
}
val op = newTestOperation<Unit, Unit>(req, Unit)

op.context.attributes[HttpOperationContext.DefaultChecksumAlgorithm] = "CRC32"
op.interceptors.add(
HttpChecksumRequiredInterceptor(),
)

val crc32 = Crc32()
crc32.update(testBody)
val expected = crc32.digest().encodeBase64String()

op.roundTrip(client, Unit)
val call = op.context.attributes[HttpOperationContext.HttpCallList].first()
assertEquals(expected, call.request.headers["x-amz-checksum-crc32"])
}

@Test
fun itOnlySetsHeaderForBytesContent() = runTest {
val req = HttpRequestBuilder().apply {
Expand All @@ -50,10 +74,9 @@ class HttpChecksumRequiredInterceptorTest {
}
val op = newTestOperation<Unit, Unit>(req, Unit)

op.context.attributes[HttpOperationContext.DefaultChecksumAlgorithm] = "MD5"
op.interceptors.add(
HttpChecksumRequiredInterceptor<Unit> {
true
},
HttpChecksumRequiredInterceptor(),
)

op.roundTrip(client, Unit)
Expand All @@ -69,9 +92,7 @@ class HttpChecksumRequiredInterceptorTest {
val op = newTestOperation<Unit, Unit>(req, Unit)

op.interceptors.add(
HttpChecksumRequiredInterceptor<Unit> {
false // interceptor disabled
},
HttpChecksumRequiredInterceptor(),
)

op.roundTrip(client, Unit)
Expand Down
Loading

0 comments on commit d9f0659

Please sign in to comment.