Skip to content

Commit

Permalink
refactor http label and query validation to happen during operation s…
Browse files Browse the repository at this point in the history
…erialization
  • Loading branch information
aajtodd committed Oct 2, 2023
1 parent aeeed62 commit 0b8b0b1
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object SymbolProperty {
* set default so the constructor will have to generate a runtime check that a value is set.
*/
val Symbol.isRequiredWithNoDefault: Boolean
get() = !isNullable && defaultValue() == null
get() = isNotNullable && defaultValue() == null

/**
* Test if a symbol is nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.ClientErrorCorrection
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.*

Expand Down Expand Up @@ -85,7 +84,7 @@ class StructureGenerator(
"public"
}

if (memberShape.isRequiredInStruct && memberSymbol.isRequiredWithNoDefault) {
if (memberSymbol.isRequiredWithNoDefault) {
writer.write(
"""#1L val #2L: #3F = requireNotNull(builder.#2L) { "A non-null value must be provided for #2L" }""",
prefix,
Expand All @@ -95,27 +94,8 @@ class StructureGenerator(
} else {
writer.write("#1L val #2L: #3F = builder.#2L", prefix, memberName, memberSymbol)
}

if (memberShape.isNonBlankInStruct) {
writer
.indent()
.write(
""".apply { require(isNotBlank()) { "A non-blank value must be provided for #L" } }""",
memberName,
)
.dedent()
}
}

private val MemberShape.isRequiredInStruct
get() =
hasTrait<HttpLabelTrait>() || isRequired

private val MemberShape.isNonBlankInStruct: Boolean
get() =
ctx.model.expectShape(target).isStringShape &&
getTrait<LengthTrait>()?.min?.getOrNull()?.takeIf { it > 0 } != null

private fun renderCompanionObject() {
writer.withBlock("public companion object {", "}") {
write("public operator fun invoke(block: Builder.() -> #Q): #Q = Builder().apply(block).build()", KotlinTypes.Unit, symbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.lang.toEscapedLiteral
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.shapes.*
Expand Down Expand Up @@ -287,6 +288,22 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val pathBindings = requestBindings.filter { it.location == HttpBinding.Location.LABEL }

if (pathBindings.isNotEmpty()) {
// One of the few places we generate client side validation
// httpLabel bindings must be non-null
httpTrait.uri.segments.filter { it.isLabel || it.isGreedyLabel }.forEach { segment ->
val binding = pathBindings.find {
it.memberName == segment.content
} ?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}`")

val memberSymbol = ctx.symbolProvider.toSymbol(binding.member)
if (memberSymbol.isNullable) {
writer.write("""requireNotNull(input.#1L) { "#1L is bound to the URI and must not be null" }""", binding.member.defaultName())
}

// check length trait if applicable
renderNonBlankGuard(ctx, binding.member, writer)
}

writer.openBlock("val pathSegments = listOf<String>(", ")") {
httpTrait.uri.segments.forEach { segment ->
if (segment.isLabel || segment.isGreedyLabel) {
Expand All @@ -297,21 +314,22 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {

// shape must be string, number, boolean, or timestamp
val targetShape = ctx.model.expectShape(binding.member.target)
val memberSymbol = ctx.symbolProvider.toSymbol(binding.member)
val identifier = if (targetShape.isTimestampShape) {
writer.addImport(RuntimeTypes.Core.TimestampFormat)
val tsFormat = resolver.determineTimestampFormat(
binding.member,
HttpBinding.Location.LABEL,
defaultTimestampFormat,
)
val tsLabel = formatInstant("input.${binding.member.defaultName()}?", tsFormat, forceString = true)
val nullCheck = if (memberSymbol.isNullable) "?" else ""
val tsLabel = formatInstant("input.${binding.member.defaultName()}$nullCheck", tsFormat, forceString = true)
tsLabel
} else {
"input.${binding.member.defaultName()}"
}

val encodeSymbol = RuntimeTypes.Http.Util.encodeLabel
writer.addImport(encodeSymbol)
val encodeFn = if (segment.isGreedyLabel) {
writer.format("#T(greedy = true)", encodeSymbol)
} else {
Expand Down Expand Up @@ -362,6 +380,9 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
writer.write("append(#S, #S)", key, value)
}

// render length check if applicable
queryBindings.forEach { binding -> renderNonBlankGuard(ctx, binding.member, writer) }

renderStringValuesMapParameters(ctx, queryBindings, writer)

queryMapBindings.forEach {
Expand Down Expand Up @@ -1001,3 +1022,14 @@ fun OperationShape.errorHandler(settings: KotlinSettings, block: SymbolRenderer)
definitionFile = "${deserializerName()}.kt"
renderBy = block
}

private fun renderNonBlankGuard(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: KotlinWriter) {
if (member.isNonBlankInStruct(ctx)) {
val memberSymbol = ctx.symbolProvider.toSymbol(member)
val nullCheck = if (memberSymbol.isNullable) "?" else ""
writer.write("""require(input.#1L$nullCheck.isNotBlank() == true) { "#1L is bound to the URI and must be a non-blank value" }""", member.defaultName())
}
}
private fun MemberShape.isNonBlankInStruct(ctx: ProtocolGenerator.GenerationContext): Boolean =
ctx.model.expectShape(target).isStringShape &&
getTrait<LengthTrait>()?.min?.getOrNull()?.takeIf { it > 0 } != null
Original file line number Diff line number Diff line change
Expand Up @@ -361,66 +361,6 @@ class StructureGeneratorTest {
generated2.shouldContainOnlyOnceWithDiff(expected2)
}

@Test
fun `it handles required HTTP fields in initializers`() {
val model = """
@http(method: "POST", uri: "/foo/{bar}/{baz}")
operation Foo {
input: FooRequest
}
structure FooRequest {
@required
@httpLabel
bar: String,
@httpLabel
@required
baz: Integer,
@httpPayload
qux: String,
@required
@httpQuery("quux")
quux: Boolean,
@httpQuery("corge")
corge: String,
@required
@length(min: 0)
@httpQuery("grault")
grault: String,
@required
@length(min: 3)
@httpQuery("garply")
garply: String
}
""".prependNamespaceAndService(operations = listOf("Foo")).toSmithyModel()

val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model)
val writer = KotlinWriter(TestModelDefault.NAMESPACE)
val struct = model.expectShape<StructureShape>("com.test#FooRequest")
val renderingCtx = RenderingContext(writer, struct, model, provider, model.defaultSettings())
StructureGenerator(renderingCtx).render()

val generated = writer.toString()
val expected = """
public class FooRequest private constructor(builder: Builder) {
public val bar: kotlin.String = requireNotNull(builder.bar) { "A non-null value must be provided for bar" }
public val baz: kotlin.Int = requireNotNull(builder.baz) { "A non-null value must be provided for baz" }
public val corge: kotlin.String? = builder.corge
public val garply: kotlin.String = requireNotNull(builder.garply) { "A non-null value must be provided for garply" }
.apply { require(isNotBlank()) { "A non-blank value must be provided for garply" } }
public val grault: kotlin.String = requireNotNull(builder.grault) { "A non-null value must be provided for grault" }
public val quux: kotlin.Boolean = requireNotNull(builder.quux) { "A non-null value must be provided for quux" }
public val qux: kotlin.String? = builder.qux
""".formatForTest(indent = "")
generated.shouldContainOnlyOnceWithDiff(expected)
}

@Test
fun `it handles required query params in initializers`() {
val model = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ internal class EnumInputOperationSerializer: HttpSerialize<EnumInputRequest> {
fun itSerializesOperationInputsWithTimestamps() {
val contents = getTransformFileContents("TimestampInputOperationSerializer.kt")
contents.assertBalancedBracesAndParens()
val tsLabel = "\${input.tsLabel?.format(TimestampFormat.ISO_8601)}" // workaround for raw strings not being able to contain escapes
val tsLabel = "\${input.tsLabel.format(TimestampFormat.ISO_8601)}" // workaround for raw strings not being able to contain escapes
val expectedContents = """
internal class TimestampInputOperationSerializer: HttpSerialize<TimestampInputRequest> {
override suspend fun serialize(context: ExecutionContext, input: TimestampInputRequest): HttpRequestBuilder {
Expand Down Expand Up @@ -510,4 +510,74 @@ internal class SmokeTestOperationDeserializer: HttpDeserialize<SmokeTestResponse
"""
contents.shouldContainOnlyOnceWithDiff(expected)
}

@Test
fun itValidatesRequiredAndNonBlankURIBindings() {
val model = """
@http(method: "POST", uri: "/foo/{bar}/{baz}")
operation Foo {
input: FooRequest
}
@input
structure FooRequest {
@required
@length(min: 3)
@httpLabel
bar: String,
@httpLabel
@required
baz: Integer,
@httpPayload
qux: String,
@required
@httpQuery("quux")
quux: Boolean,
@httpQuery("corge")
corge: String,
@required
@length(min: 0)
@httpQuery("grault")
grault: String,
@required
@length(min: 3)
@httpQuery("garply")
garply: String
}
""".prependNamespaceAndService(operations = listOf("Foo"))
.toSmithyModel()

val contents = getTransformFileContents("FooOperationSerializer.kt", model)

val label1 = "\${input.bar}"
val label2 = "\${input.baz}"
val quux = "\${input.quux}"
val expected = """
requireNotNull(input.bar) { "bar is bound to the URI and must not be null" }
require(input.bar?.isNotBlank() == true) { "bar is bound to the URI and must be a non-blank value" }
requireNotNull(input.baz) { "baz is bound to the URI and must not be null" }
val pathSegments = listOf<String>(
"foo",
"$label1".encodeLabel(),
"$label2".encodeLabel(),
)
path = pathSegments.joinToString(separator = "/", prefix = "/")
parameters {
require(input.garply?.isNotBlank() == true) { "garply is bound to the URI and must be a non-blank value" }
if (input.corge != null) append("corge", input.corge)
if (input.garply != null) append("garply", input.garply)
if (input.grault != null) append("grault", input.grault)
if (input.quux != null) append("quux", "$quux")
}
"""

contents.shouldContainOnlyOnceWithDiff(expected)
}
}

0 comments on commit 0b8b0b1

Please sign in to comment.