Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor XML deserialize #1233

Merged
merged 11 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changes/119ee420-38a5-4974-922e-29cb11de02d0.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "119ee420-38a5-4974-922e-29cb11de02d0",
"type": "bugfix",
"description": "Refactor XML deserialization to handle flat collections",
"issues": [
"awslabs/aws-sdk-kotlin#1220"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,32 @@ import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.traits.UnwrappedXmlOutput
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.transform.ModelTransformer

/**
* Applies the [UnwrappedXmlOutput] custom-made [annotation trait](https://smithy.io/2.0/spec/model.html?highlight=annotation#annotation-traits) to structures
* whose operation is annotated with `S3UnwrappedXmlOutput` trait to mark when special unwrapped xml output deserialization is required.
* Applies the custom [UnwrappedXmlOutput]
* [annotation trait](https://smithy.io/2.0/spec/model.html?highlight=annotation#annotation-traits) to operations
* annotated with `S3UnwrappedXmlOutput` trait to mark when special unwrapped xml output deserialization is required.
*/
class UnwrappedXmlOutputIntegration : KotlinIntegration {
override fun enabledForService(model: Model, settings: KotlinSettings): Boolean =
model.expectShape<ServiceShape>(settings.service).isS3

override fun preprocessModel(model: Model, settings: KotlinSettings): Model {
val unwrappedStructures = model
val unwrappedOperations = model
.operationShapes
.filter { it.hasTrait<S3UnwrappedXmlOutputTrait>() }
.map { it.outputShape }
.map { it.id }
.toSet()

return ModelTransformer
.create()
.mapShapes(model) { shape ->
when {
shape.id in unwrappedStructures ->
(shape as StructureShape).toBuilder().addTrait(UnwrappedXmlOutput()).build()
shape.id in unwrappedOperations ->
(shape as OperationShape).toBuilder().addTrait(UnwrappedXmlOutput()).build()
else -> shape
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AwsHttpBindingPr
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator
import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
import software.amazon.smithy.kotlin.codegen.rendering.protocol.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.*

Expand Down Expand Up @@ -68,24 +67,6 @@ private class AwsQuerySerdeFormUrlDescriptorGenerator(
member.hasTrait<XmlFlattenedTrait>()
}

private class AwsQuerySerdeXmlDescriptorGenerator(
ctx: RenderingContext<Shape>,
memberShapes: List<MemberShape>? = null,
) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {

override fun getObjectDescriptorTraits(): List<SdkFieldDescriptorTrait> {
val traits = super.getObjectDescriptorTraits().toMutableList()

if (objectShape.hasTrait<OperationOutput>()) {
traits.removeIf { it.symbol == RuntimeTypes.Serde.SerdeXml.XmlSerialName }
val serialName = objectShape.changeNameSuffix("Response" to "Result")
traits.add(RuntimeTypes.Serde.SerdeXml.XmlSerialName, serialName.dq())
}

return traits
}
}

private class AwsQuerySerializerGenerator(
private val protocolGenerator: AwsQuery,
) : AbstractQueryFormUrlSerializerGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
Expand All @@ -98,50 +79,76 @@ private class AwsQuerySerializerGenerator(
}

private class AwsQueryXmlParserGenerator(
private val protocolGenerator: AwsQuery,
) : XmlParserGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {

override fun descriptorGenerator(
ctx: ProtocolGenerator.GenerationContext,
shape: Shape,
members: List<MemberShape>,
writer: KotlinWriter,
): XmlSerdeDescriptorGenerator = AwsQuerySerdeXmlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members)

override fun renderDeserializeOperationBody(
ctx: ProtocolGenerator.GenerationContext,
op: OperationShape,
documentMembers: List<MemberShape>,
writer: KotlinWriter,
) {
writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer)
unwrapOperationResponseBody(op.id.name, writer)
val shape = ctx.model.expectShape(op.output.get())
renderDeserializerBody(ctx, shape, documentMembers, writer)
}
protocolGenerator: AwsQuery,
) : XmlParserGenerator(protocolGenerator.defaultTimestampFormat) {

/**
* Unwraps the response body as specified by
* https://awslabs.github.io/smithy/1.0/spec/aws/aws-query-protocol.html#response-serialization so that the
* deserializer is in the correct state.
*
* ```
* <SomeOperationResponse>
* <SomeOperationResult>
* <-- SAME AS REST XML -->
* </SomeOperationResult>
*</SomeOperationResponse>
* ```
*/
private fun unwrapOperationResponseBody(
operationName: String,
override fun unwrapOperationBody(
ctx: ProtocolGenerator.GenerationContext,
serdeCtx: SerdeCtx,
op: OperationShape,
writer: KotlinWriter,
) {
writer.write("// begin unwrap response wrapper")
.write("val resultDescriptor = #T(#T.Struct, #T(#S))", RuntimeTypes.Serde.SdkFieldDescriptor, RuntimeTypes.Serde.SerialKind, RuntimeTypes.Serde.SerdeXml.XmlSerialName, "${operationName}Result")
.withBlock("val wrapperDescriptor = #T.build {", "}", RuntimeTypes.Serde.SdkObjectDescriptor) {
write("trait(#T(#S))", RuntimeTypes.Serde.SerdeXml.XmlSerialName, "${operationName}Response")
write("#T(resultDescriptor)", RuntimeTypes.Serde.field)
): SerdeCtx {
val operationName = op.id.getName(ctx.service)

val unwrapAwsQueryOperation = buildSymbol {
name = "unwrapAwsQueryResponse"
namespace = ctx.settings.pkg.serde
definitionFile = "AwsQueryUtil.kt"
renderBy = { writer ->

writer.withBlock(
"internal fun $name(root: #1T, operationName: #2T): #1T {",
"}",
RuntimeTypes.Serde.SerdeXml.XmlTagReader,
KotlinTypes.String,
) {
write("val responseWrapperName = \"\${operationName}Response\"")
write("val resultWrapperName = \"\${operationName}Result\"")
withBlock(
"if (root.tagName != responseWrapperName) {",
"}",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid root, expected \$responseWrapperName; found `\${root.tag}`")
}

write("val resultTag = ${serdeCtx.tagReader}.nextTag()")
withBlock(
"if (resultTag == null || resultTag.tagName != resultWrapperName) {",
"}",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid result, expected \$resultWrapperName; found `\${resultTag?.tag}`")
}

write("return resultTag")
}
}
.write("")
// abandon the iterator, this only occurs at the top level operational output
.write("val wrapper = deserializer.#T(wrapperDescriptor)", RuntimeTypes.Serde.deserializeStruct)
.withBlock("if (wrapper.findNextFieldIndex() != resultDescriptor.index) {", "}") {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "failed to unwrap $operationName response")
}
.write("// end unwrap response wrapper")
.write("")
}

writer.write("val unwrapped = #T(#L, #S)", unwrapAwsQueryOperation, serdeCtx.tagReader, operationName)

return SerdeCtx("unwrapped")
}

override fun unwrapOperationError(
ctx: ProtocolGenerator.GenerationContext,
serdeCtx: SerdeCtx,
errorShape: StructureShape,
writer: KotlinWriter,
): SerdeCtx {
writer.write("val errReader = #T(${serdeCtx.tagReader})", RestXmlErrors.wrappedErrorResponseDeserializer(ctx))
return SerdeCtx("errReader")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@ package software.amazon.smithy.kotlin.codegen.aws.protocols

import software.amazon.smithy.aws.traits.protocols.Ec2QueryNameTrait
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AbstractQueryFormUrlSerializerGenerator
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator
import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RenderingContext
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.model.changeNameSuffix
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.getTrait
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.XmlNameTrait

Expand Down Expand Up @@ -73,24 +72,6 @@ private class Ec2QuerySerdeFormUrlDescriptorGenerator(
targetShape.type == ShapeType.LIST
}

private class Ec2QuerySerdeXmlDescriptorGenerator(
ctx: RenderingContext<Shape>,
memberShapes: List<MemberShape>? = null,
) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {

override fun getObjectDescriptorTraits(): List<SdkFieldDescriptorTrait> {
val traits = super.getObjectDescriptorTraits().toMutableList()

if (objectShape.hasTrait<OperationOutput>()) {
traits.removeIf { it.symbol == RuntimeTypes.Serde.SerdeXml.XmlSerialName }
val serialName = objectShape.changeNameSuffix("Response" to "Result")
traits.add(RuntimeTypes.Serde.SerdeXml.XmlSerialName, serialName.dq())
}

return traits
}
}

private class Ec2QuerySerializerGenerator(
private val protocolGenerator: Ec2Query,
) : AbstractQueryFormUrlSerializerGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
Expand All @@ -104,13 +85,73 @@ private class Ec2QuerySerializerGenerator(
}

private class Ec2QueryParserGenerator(
private val protocolGenerator: Ec2Query,
) : XmlParserGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {

override fun descriptorGenerator(
protocolGenerator: Ec2Query,
) : XmlParserGenerator(protocolGenerator.defaultTimestampFormat) {
override fun unwrapOperationError(
ctx: ProtocolGenerator.GenerationContext,
shape: Shape,
members: List<MemberShape>,
serdeCtx: SerdeCtx,
errorShape: StructureShape,
writer: KotlinWriter,
): XmlSerdeDescriptorGenerator = Ec2QuerySerdeXmlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members)
): SerdeCtx {
val unwrapFn = unwrapErrorResponse(ctx)
writer.write("val errReader = #T(${serdeCtx.tagReader})", unwrapFn)
return SerdeCtx("errReader")
}

/**
* Error deserializer for a wrapped error response
*
* ```
* <Response>
* <Errors>
* <Error>
* <-- DATA -->>
* </Error>
* </Errors>
* </Response>
* ```
*
* See https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization
*/
private fun unwrapErrorResponse(ctx: ProtocolGenerator.GenerationContext): Symbol = buildSymbol {
name = "unwrapXmlErrorResponse"
namespace = ctx.settings.pkg.serde
definitionFile = "XmlErrorUtils.kt"
renderBy = { writer ->
writer.dokka("Handle [wrapped](https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization) error responses")
writer.withBlock(
"internal fun $name(root: #1T): #1T {",
"}",
RuntimeTypes.Serde.SerdeXml.XmlTagReader,
) {
withBlock(
"if (root.tagName != #S) {",
"}",
"Response",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid root, expected <Response>; found `\${root.tag}`")
}

write("val errorsTag = root.nextTag()")
withBlock(
"if (errorsTag == null || errorsTag.tagName != #S) {",
"}",
"Errors",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid error, expected <Errors>; found `\${errorsTag?.tag}`")
}

write("val errTag = errorsTag.nextTag()")
withBlock(
"if (errTag == null || errTag.tagName != #S) {",
"}",
"Error",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid error, expected <Error>; found `\${errTag?.tag}`")
}

write("return errTag")
}
}
}
}
Loading
Loading