Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Customize S3's Expires field #1287

Merged
merged 15 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/14eca09d-0ee6-45f3-a758-5e5b2ac471f0.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "14eca09d-0ee6-45f3-a758-5e5b2ac471f0",
"type": "feature",
"description": "Customize S3's `Expires` field, including adding a new `ExpiresString` field for output types."
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.sdk.kotlin.codegen.customization.s3

import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.DeprecatedTrait
import software.amazon.smithy.model.traits.DocumentationTrait
import software.amazon.smithy.model.traits.HttpHeaderTrait
import software.amazon.smithy.model.traits.OutputTrait
import software.amazon.smithy.model.transform.ModelTransformer
import kotlin.streams.asSequence

/**
* An integration used to customize behavior around S3's members named `Expires`.
*/
class S3ExpiresIntegration : KotlinIntegration {
override fun enabledForService(model: Model, settings: KotlinSettings) =
model.expectShape<ServiceShape>(settings.service).isS3 && model.shapes<OperationShape>().any { it.hasExpiresMember(model) }

override fun preprocessModel(model: Model, settings: KotlinSettings): Model {
val transformer = ModelTransformer.create()

// Ensure all `Expires` shapes are timestamps
val expiresShapeTimestampMap = model.shapes()
.asSequence()
.mapNotNull { shape ->
shape.members()
.singleOrNull { member -> member.memberName.lowercase() == "expires" }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Prefer .equals("expires", ignoreCase = true) here and elsewhere

?.target
}
.associateWith { ShapeType.TIMESTAMP }

var transformedModel = transformer.changeShapeType(model, expiresShapeTimestampMap)

// Add an `ExpiresString` string shape to the model
val expiresString = StringShape.builder()
expiresString.id("aws.sdk.kotlin.s3.synthetic#ExpiresString")
transformedModel = transformedModel.toBuilder().addShape(expiresString.build()).build()

// For output shapes only, deprecate `Expires` and add a synthetic member that targets `ExpiresString`
return transformer.mapShapes(transformedModel) { shape ->
if (shape.hasTrait<OutputTrait>() && shape.memberNames.any { it.lowercase() == "expires" }) {
val builder = (shape as StructureShape).toBuilder()

// Deprecate `Expires`
val expiresMember = shape.members().single { it.memberName.lowercase() == "expires" }

builder.removeMember(expiresMember.memberName)
val deprecatedTrait = DeprecatedTrait.builder()
.message("Please use `expiresString` which contains the raw, unparsed value of this field.")
.since("2024-04-16")
.build()

builder.addMember(
expiresMember.toBuilder()
.addTrait(deprecatedTrait)
.build(),
)

// Add a synthetic member targeting `ExpiresString`
val expiresStringMember = MemberShape.builder()
expiresStringMember.target(expiresString.id)
expiresStringMember.id(expiresMember.id.toString() + "String") // i.e. com.amazonaws.s3.<MEMBER_NAME>$ExpiresString
expiresStringMember.addTrait(HttpHeaderTrait("ExpiresString")) // Add HttpHeaderTrait to ensure the field is deserialized
expiresMember.getTrait<DocumentationTrait>()?.let {
expiresStringMember.addTrait(it) // Copy documentation from `Expires`
}
builder.addMember(expiresStringMember.build())
builder.build()
} else {
shape
}
}
}

override fun customizeMiddleware(
ctx: ProtocolGenerator.GenerationContext,
resolved: List<ProtocolMiddleware>,
): List<ProtocolMiddleware> = resolved + applyExpiresFieldInterceptor

internal val applyExpiresFieldInterceptor = object : ProtocolMiddleware {
override val name: String = "ExpiresFieldInterceptor"

override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean =
ctx.model.expectShape<ServiceShape>(ctx.settings.service).isS3 && op.hasExpiresMember(ctx.model)

override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
val interceptorSymbol = buildSymbol {
name = "ExpiresFieldInterceptor"
namespace = ctx.settings.pkg.subpackage("internal")
}

writer.write("op.interceptors.add(#T)", interceptorSymbol)
}
}

private fun OperationShape.hasExpiresMember(model: Model): Boolean {
val input = model.expectShape(this.inputShape)
val output = model.expectShape(this.outputShape)

return (input.memberNames + output.memberNames).any {
it.lowercase() == "expires"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ aws.sdk.kotlin.codegen.customization.SigV4AsymmetricTraitCustomization
aws.sdk.kotlin.codegen.customization.cloudfrontkeyvaluestore.BackfillSigV4ACustomization
aws.sdk.kotlin.codegen.customization.s3.express.SigV4S3ExpressAuthSchemeIntegration
aws.sdk.kotlin.codegen.customization.s3.express.S3ExpressIntegration
aws.sdk.kotlin.codegen.customization.s3.S3ExpiresIntegration
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.sdk.kotlin.codegen.customization.s3

import aws.sdk.kotlin.codegen.testutil.model
import software.amazon.smithy.kotlin.codegen.model.isDeprecated
import software.amazon.smithy.kotlin.codegen.model.targetOrSelf
import software.amazon.smithy.kotlin.codegen.test.defaultSettings
import software.amazon.smithy.kotlin.codegen.test.newTestContext
import software.amazon.smithy.kotlin.codegen.test.toSmithyModel
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.shapes.ShapeId
import kotlin.test.*
import kotlin.test.assertTrue

class S3ExpiresIntegrationTest {
private val testModel = """
namespace smithy.example

use aws.protocols#restXml
use aws.auth#sigv4
use aws.api#service

@restXml
@sigv4(name: "s3")
@service(
sdkId: "S3"
arnNamespace: "s3"
)
service S3 {
version: "1.0.0",
operations: [GetFoo, NewGetFoo]
}

operation GetFoo {
input: GetFooInput
output: GetFooOutput
}

operation NewGetFoo {
input: GetFooInput
output: NewGetFooOutput
}

structure GetFooInput {
payload: String
expires: String
}

@output
structure GetFooOutput {
expires: Timestamp
}

@output
structure NewGetFooOutput {
expires: String
}
""".toSmithyModel()

@Test
fun testEnabledForS3() {
val enabled = S3ExpiresIntegration().enabledForService(testModel, testModel.defaultSettings())
assertTrue(enabled)
}

@Test
fun testDisabledForNonS3Model() {
val model = model("NotS3")
val enabled = S3ExpiresIntegration().enabledForService(model, model.defaultSettings())
assertFalse(enabled)
}

@Test
fun testMiddlewareAddition() {
val model = model("S3")
val preexistingMiddleware = listOf(FooMiddleware)
val ctx = model.newTestContext("S3")

val integration = S3ExpiresIntegration()
val actual = integration.customizeMiddleware(ctx.generationCtx, preexistingMiddleware)

assertEquals(listOf(FooMiddleware, integration.applyExpiresFieldInterceptor), actual)
}

@Test
fun testPreprocessModel() {
val integration = S3ExpiresIntegration()
val model = integration.preprocessModel(testModel, testModel.defaultSettings())

val expiresShapes = listOf(
model.expectShape(ShapeId.from("smithy.example#GetFooInput\$expires")),
model.expectShape(ShapeId.from("smithy.example#GetFooOutput\$expires")),
model.expectShape(ShapeId.from("smithy.example#NewGetFooOutput\$expires")),
)
// `Expires` members should always be Timestamp, even if its modeled as a string
assertTrue(expiresShapes.all { it.targetOrSelf(model).isTimestampShape })

// All `Expires` output members should be deprecated
assertTrue(
expiresShapes
.filter { it.id.toString().contains("Output") }
.all { it.isDeprecated },
)

val expiresStringFields = listOf(
model.expectShape(ShapeId.from("smithy.example#GetFooOutput\$expiresString")),
model.expectShape(ShapeId.from("smithy.example#NewGetFooOutput\$expiresString")),
)
// There should be no `ExpiresString` member added to the input shape
assertNull(model.getShape(ShapeId.from("smithy.example#GetFooInput\$expiresString")).getOrNull())

// There should be a synthetic `ExpiresString` string member added to output shapes
assertTrue(expiresStringFields.all { it.targetOrSelf(model).isStringShape })

// The synthetic fields should NOT be deprecated
assertTrue(expiresStringFields.none { it.isDeprecated })
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.sdk.kotlin.services.s3.internal

import aws.smithy.kotlin.runtime.client.ProtocolResponseInterceptorContext
import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor
import aws.smithy.kotlin.runtime.http.request.HttpRequest
import aws.smithy.kotlin.runtime.http.response.HttpResponse
import aws.smithy.kotlin.runtime.http.response.toBuilder
import aws.smithy.kotlin.runtime.telemetry.logging.logger
import aws.smithy.kotlin.runtime.time.Instant
import kotlin.coroutines.coroutineContext

/**
* Interceptor to handle special-cased `Expires` field which must not cause deserialization to fail.
*/
internal object ExpiresFieldInterceptor : HttpInterceptor {
override suspend fun modifyBeforeDeserialization(context: ProtocolResponseInterceptorContext<Any, HttpRequest, HttpResponse>): HttpResponse {
val response = context.protocolResponse.toBuilder()

if (response.headers.contains("Expires")) {
response.headers["ExpiresString"] = response.headers["Expires"]!!

// if parsing `Expires` would fail, remove the header value so it deserializes to `null`
try {
Instant.fromRfc5322(response.headers["Expires"]!!)
} catch (e: Exception) {
coroutineContext.logger<ExpiresFieldInterceptor>().warn {
"Failed to parse `expires`=${response.headers["Expires"]} as a timestamp, setting it to `null`. The unparsed value is available in `expiresString`."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Include quotes around the value which failed to parse. This makes values which are empty or contain multiple spaces easier to understand.

}
response.headers.remove("Expires")
}
}

return response.build()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.sdk.kotlin.services.s3.internal

import aws.sdk.kotlin.runtime.auth.credentials.StaticCredentialsProvider
import aws.sdk.kotlin.services.s3.S3Client
import aws.sdk.kotlin.services.s3.model.GetObjectRequest
import aws.smithy.kotlin.runtime.http.Headers
import aws.smithy.kotlin.runtime.http.HeadersBuilder
import aws.smithy.kotlin.runtime.http.HttpBody
import aws.smithy.kotlin.runtime.http.HttpStatusCode
import aws.smithy.kotlin.runtime.http.response.HttpResponse
import aws.smithy.kotlin.runtime.httptest.buildTestConnection
import aws.smithy.kotlin.runtime.time.Instant
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull

class ExpiresFieldInterceptorTest {
private fun newTestClient(
status: HttpStatusCode = HttpStatusCode.OK,
headers: Headers = Headers.Empty,
): S3Client =
S3Client {
region = "us-east-1"
credentialsProvider = StaticCredentialsProvider {
accessKeyId = "accessKeyId"
secretAccessKey = "secretAccessKey"
}
httpClient = buildTestConnection {
expect(HttpResponse(status, headers, body = HttpBody.Empty))
}
}

@Test
fun testHandlesParsableExpiresField() = runTest {
val expectedHeaders = HeadersBuilder().apply {
append("Expires", "Mon, 1 Apr 2024 00:00:00 +0000")
}.build()

val s3 = newTestClient(headers = expectedHeaders)
s3.getObject(
GetObjectRequest {
bucket = "test"
key = "key"
},
) {
assertEquals(Instant.fromEpochSeconds(1711929600), it.expires)
assertEquals("Mon, 1 Apr 2024 00:00:00 +0000", it.expiresString)
}
}

@Test
fun testHandlesUnparsableExpiresField() = runTest {
val invalidExpiresField = "Tomorrow or maybe the day after?"

val expectedHeaders = HeadersBuilder().apply {
append("Expires", invalidExpiresField)
}.build()

val s3 = newTestClient(headers = expectedHeaders)
s3.getObject(
GetObjectRequest {
bucket = "test"
key = "key"
},
) {
assertNull(it.expires)
assertEquals(invalidExpiresField, it.expiresString)
}
}
}
Loading