From 0b8b0b1d99ea96e315b28d7ea9a54bd902d552ae Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Mon, 2 Oct 2023 13:08:19 -0400 Subject: [PATCH] refactor http label and query validation to happen during operation serialization --- .../smithy/kotlin/codegen/model/SymbolExt.kt | 2 +- .../codegen/rendering/StructureGenerator.kt | 22 +----- .../protocol/HttpBindingProtocolGenerator.kt | 36 +++++++++- .../rendering/StructureGeneratorTest.kt | 60 ---------------- .../HttpBindingProtocolGeneratorTest.kt | 72 ++++++++++++++++++- 5 files changed, 107 insertions(+), 85 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/SymbolExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/SymbolExt.kt index 037c3cf273..0f39d91095 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/SymbolExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/SymbolExt.kt @@ -59,7 +59,7 @@ object SymbolProperty { * set default so the constructor will have to generate a runtime check that a value is set. */ val Symbol.isRequiredWithNoDefault: Boolean - get() = !isNullable && defaultValue() == null + get() = isNotNullable && defaultValue() == null /** * Test if a symbol is nullable diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGenerator.kt index e6941ddd23..c5a88601ca 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGenerator.kt @@ -10,7 +10,6 @@ import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.rendering.serde.ClientErrorCorrection -import software.amazon.smithy.kotlin.codegen.utils.getOrNull import software.amazon.smithy.model.shapes.* import software.amazon.smithy.model.traits.* @@ -85,7 +84,7 @@ class StructureGenerator( "public" } - if (memberShape.isRequiredInStruct && memberSymbol.isRequiredWithNoDefault) { + if (memberSymbol.isRequiredWithNoDefault) { writer.write( """#1L val #2L: #3F = requireNotNull(builder.#2L) { "A non-null value must be provided for #2L" }""", prefix, @@ -95,27 +94,8 @@ class StructureGenerator( } else { writer.write("#1L val #2L: #3F = builder.#2L", prefix, memberName, memberSymbol) } - - if (memberShape.isNonBlankInStruct) { - writer - .indent() - .write( - """.apply { require(isNotBlank()) { "A non-blank value must be provided for #L" } }""", - memberName, - ) - .dedent() - } } - private val MemberShape.isRequiredInStruct - get() = - hasTrait() || isRequired - - private val MemberShape.isNonBlankInStruct: Boolean - get() = - ctx.model.expectShape(target).isStringShape && - getTrait()?.min?.getOrNull()?.takeIf { it > 0 } != null - private fun renderCompanionObject() { writer.withBlock("public companion object {", "}") { write("public operator fun invoke(block: Builder.() -> #Q): #Q = Builder().apply(block).build()", KotlinTypes.Unit, symbol) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt index 9e0881be07..2ac8bac37b 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.lang.toEscapedLiteral import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.rendering.serde.* +import software.amazon.smithy.kotlin.codegen.utils.getOrNull import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.shapes.* @@ -287,6 +288,22 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val pathBindings = requestBindings.filter { it.location == HttpBinding.Location.LABEL } if (pathBindings.isNotEmpty()) { + // One of the few places we generate client side validation + // httpLabel bindings must be non-null + httpTrait.uri.segments.filter { it.isLabel || it.isGreedyLabel }.forEach { segment -> + val binding = pathBindings.find { + it.memberName == segment.content + } ?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}`") + + val memberSymbol = ctx.symbolProvider.toSymbol(binding.member) + if (memberSymbol.isNullable) { + writer.write("""requireNotNull(input.#1L) { "#1L is bound to the URI and must not be null" }""", binding.member.defaultName()) + } + + // check length trait if applicable + renderNonBlankGuard(ctx, binding.member, writer) + } + writer.openBlock("val pathSegments = listOf(", ")") { httpTrait.uri.segments.forEach { segment -> if (segment.isLabel || segment.isGreedyLabel) { @@ -297,6 +314,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { // shape must be string, number, boolean, or timestamp val targetShape = ctx.model.expectShape(binding.member.target) + val memberSymbol = ctx.symbolProvider.toSymbol(binding.member) val identifier = if (targetShape.isTimestampShape) { writer.addImport(RuntimeTypes.Core.TimestampFormat) val tsFormat = resolver.determineTimestampFormat( @@ -304,14 +322,14 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { HttpBinding.Location.LABEL, defaultTimestampFormat, ) - val tsLabel = formatInstant("input.${binding.member.defaultName()}?", tsFormat, forceString = true) + val nullCheck = if (memberSymbol.isNullable) "?" else "" + val tsLabel = formatInstant("input.${binding.member.defaultName()}$nullCheck", tsFormat, forceString = true) tsLabel } else { "input.${binding.member.defaultName()}" } val encodeSymbol = RuntimeTypes.Http.Util.encodeLabel - writer.addImport(encodeSymbol) val encodeFn = if (segment.isGreedyLabel) { writer.format("#T(greedy = true)", encodeSymbol) } else { @@ -362,6 +380,9 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { writer.write("append(#S, #S)", key, value) } + // render length check if applicable + queryBindings.forEach { binding -> renderNonBlankGuard(ctx, binding.member, writer) } + renderStringValuesMapParameters(ctx, queryBindings, writer) queryMapBindings.forEach { @@ -1001,3 +1022,14 @@ fun OperationShape.errorHandler(settings: KotlinSettings, block: SymbolRenderer) definitionFile = "${deserializerName()}.kt" renderBy = block } + +private fun renderNonBlankGuard(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: KotlinWriter) { + if (member.isNonBlankInStruct(ctx)) { + val memberSymbol = ctx.symbolProvider.toSymbol(member) + val nullCheck = if (memberSymbol.isNullable) "?" else "" + writer.write("""require(input.#1L$nullCheck.isNotBlank() == true) { "#1L is bound to the URI and must be a non-blank value" }""", member.defaultName()) + } +} +private fun MemberShape.isNonBlankInStruct(ctx: ProtocolGenerator.GenerationContext): Boolean = + ctx.model.expectShape(target).isStringShape && + getTrait()?.min?.getOrNull()?.takeIf { it > 0 } != null diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGeneratorTest.kt index 042c8907b5..e12a6686a4 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGeneratorTest.kt @@ -361,66 +361,6 @@ class StructureGeneratorTest { generated2.shouldContainOnlyOnceWithDiff(expected2) } - @Test - fun `it handles required HTTP fields in initializers`() { - val model = """ - @http(method: "POST", uri: "/foo/{bar}/{baz}") - operation Foo { - input: FooRequest - } - - structure FooRequest { - @required - @httpLabel - bar: String, - - @httpLabel - @required - baz: Integer, - - @httpPayload - qux: String, - - @required - @httpQuery("quux") - quux: Boolean, - - @httpQuery("corge") - corge: String, - - @required - @length(min: 0) - @httpQuery("grault") - grault: String, - - @required - @length(min: 3) - @httpQuery("garply") - garply: String - } - """.prependNamespaceAndService(operations = listOf("Foo")).toSmithyModel() - - val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model) - val writer = KotlinWriter(TestModelDefault.NAMESPACE) - val struct = model.expectShape("com.test#FooRequest") - val renderingCtx = RenderingContext(writer, struct, model, provider, model.defaultSettings()) - StructureGenerator(renderingCtx).render() - - val generated = writer.toString() - val expected = """ - public class FooRequest private constructor(builder: Builder) { - public val bar: kotlin.String = requireNotNull(builder.bar) { "A non-null value must be provided for bar" } - public val baz: kotlin.Int = requireNotNull(builder.baz) { "A non-null value must be provided for baz" } - public val corge: kotlin.String? = builder.corge - public val garply: kotlin.String = requireNotNull(builder.garply) { "A non-null value must be provided for garply" } - .apply { require(isNotBlank()) { "A non-blank value must be provided for garply" } } - public val grault: kotlin.String = requireNotNull(builder.grault) { "A non-null value must be provided for grault" } - public val quux: kotlin.Boolean = requireNotNull(builder.quux) { "A non-null value must be provided for quux" } - public val qux: kotlin.String? = builder.qux - """.formatForTest(indent = "") - generated.shouldContainOnlyOnceWithDiff(expected) - } - @Test fun `it handles required query params in initializers`() { val model = """ diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGeneratorTest.kt index 47ca72cf66..da3b74284f 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGeneratorTest.kt @@ -248,7 +248,7 @@ internal class EnumInputOperationSerializer: HttpSerialize { fun itSerializesOperationInputsWithTimestamps() { val contents = getTransformFileContents("TimestampInputOperationSerializer.kt") contents.assertBalancedBracesAndParens() - val tsLabel = "\${input.tsLabel?.format(TimestampFormat.ISO_8601)}" // workaround for raw strings not being able to contain escapes + val tsLabel = "\${input.tsLabel.format(TimestampFormat.ISO_8601)}" // workaround for raw strings not being able to contain escapes val expectedContents = """ internal class TimestampInputOperationSerializer: HttpSerialize { override suspend fun serialize(context: ExecutionContext, input: TimestampInputRequest): HttpRequestBuilder { @@ -510,4 +510,74 @@ internal class SmokeTestOperationDeserializer: HttpDeserialize( + "foo", + "$label1".encodeLabel(), + "$label2".encodeLabel(), + ) + path = pathSegments.joinToString(separator = "/", prefix = "/") + parameters { + require(input.garply?.isNotBlank() == true) { "garply is bound to the URI and must be a non-blank value" } + if (input.corge != null) append("corge", input.corge) + if (input.garply != null) append("garply", input.garply) + if (input.grault != null) append("grault", input.grault) + if (input.quux != null) append("quux", "$quux") + } + """ + + contents.shouldContainOnlyOnceWithDiff(expected) + } }