From 3578d7f671d70d982f9e7b88ade5907e819cec5b Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Mon, 19 Feb 2024 15:52:53 -0500 Subject: [PATCH 01/25] add Tag scoped reader abstraction --- .../runtime/collections/CollectionExt.kt | 17 ++ .../kotlin/runtime/serde/xml/TagReader.kt | 127 +++++++++++++++ .../runtime/serde/xml/XmlStreamReader.kt | 5 + .../deserialization/LexingXmlStreamReader.kt | 28 ++-- .../kotlin/runtime/serde/xml/TagReaderTest.kt | 147 ++++++++++++++++++ .../runtime/serde/xml/XmlStreamReaderTest.kt | 24 ++- 6 files changed, 332 insertions(+), 16 deletions(-) create mode 100644 runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt create mode 100644 runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt create mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt new file mode 100644 index 000000000..e57fefbee --- /dev/null +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt @@ -0,0 +1,17 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.collections + +/** + * Creates a new list or appends to an existing one if not null. + * + * If [dest] is null this function creates a new list with element [x] and returns it. + * Otherwise, it appends [x] to [dest] and returns the given [dest] list. + */ +public fun createOrAppend(dest: MutableList?, x: T): MutableList { + if (dest == null) return mutableListOf(x) + dest.add(x) + return dest +} diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt new file mode 100644 index 000000000..60112b983 --- /dev/null +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt @@ -0,0 +1,127 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.serde.xml + +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.io.Closeable +import aws.smithy.kotlin.runtime.serde.DeserializationException + +/** + * An [XmlStreamReader] scoped to reading a single XML element [startTag] + * [TagReader] provides a "tag" scoped view into an XML document. Methods return + * `null` when the current tag has been exhausted. + */ +@InternalApi +public class TagReader( + public val startTag: XmlToken.BeginElement, + private val reader: XmlStreamReader, +) : Closeable { + private var last: TagReader? = null + private var closed = false + + public fun nextToken(): XmlToken? { + if (closed) return null + val peek = reader.peek() + if (peek.terminates(startTag)) { + // consume it and close the tag reader + reader.nextToken() + closed = true + return null + } + return reader.nextToken() + } + + public fun skipNext() { + if (closed) return + reader.skipNext() + } + + public fun skipCurrent() { + if (closed) return + reader.skipCurrent() + } + + override fun close(): Unit = drop() + + public fun drop() { + do { + val tok = nextToken() + } while (tok != null) + // // consume the end token for this element + // // FIXME - consuming the next token that ends this messes up the subtree reader state, `nextToken()` will now start + // // to return more tokens + // val next = parent.peek() + // if (next.terminates(startElement)) { + // parent.nextToken() + // } + } + + public fun nextTag(): TagReader? { + last?.drop() + + var cand = nextToken() + while (cand != null && cand !is XmlToken.BeginElement) { + cand = nextToken() + } + + val nextTok = cand as? XmlToken.BeginElement + + return nextTok?.tagReader(reader).also { newScope -> + last = newScope + } + } +} + +public fun XmlStreamReader.root(): TagReader { + val start = seek() ?: error("expected start tag: last = $lastToken") + return start.tagReader(this) +} + +/** + * Create a new reader scoped to this element. + */ +@InternalApi +public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): TagReader { + val start = reader.lastToken as? XmlToken.BeginElement ?: error("expected start tag found ${reader.lastToken}") + check(name == start.name) { "expected start tag $name but current reader state is on ${start.name}" } + return TagReader(this, reader) +} + +// @InternalApi +// public fun XmlToken.BeginElement.decode(reader: XmlStreamReader, block: TagReader.() -> T): T { +// val scoped = tagReader(reader) +// val result = block(scoped) +// // exhaust this reader +// scoped.drop() +// return result +// } + +/** + * Consume the next token and map the data value from it using [transform] + * + * If the next token is not [XmlToken.Text] an exception will be thrown + */ +public fun TagReader.map(transform: (String) -> T): T = + transform(text()) + +public fun TagReader.text(): String = + when (val next = nextToken()) { + is XmlToken.Text -> next.value ?: "" + null, is XmlToken.EndElement -> "" + else -> throw DeserializationException("expected XmlToken.Text element, found $next") + } + +private fun TagReader.mapOrThrow(expected: String, mapper: (String) -> T?): T = + map { raw -> + mapper(raw) ?: throw DeserializationException("could not deserialize $raw as $expected for tag ${this.startTag}") + } + +public fun TagReader.readInt(): Int = mapOrThrow("Int", String::toIntOrNull) +public fun TagReader.readShort(): Short = mapOrThrow("Short", String::toShortOrNull) +public fun TagReader.readLong(): Long = mapOrThrow("Long", String::toLongOrNull) +public fun TagReader.readFloat(): Float = mapOrThrow("Float", String::toFloatOrNull) +public fun TagReader.readDouble(): Double = mapOrThrow("Double", String::toDoubleOrNull) +public fun TagReader.readByte(): Byte = mapOrThrow("Byte") { it.toIntOrNull()?.toByte() } +public fun TagReader.readBoolean(): Boolean = mapOrThrow("Boolean", String::toBoolean) diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt index a005be4c4..71af12869 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt @@ -58,6 +58,11 @@ public interface XmlStreamReader { */ public fun skipNext() + /** + * Recursively skip the current token. Meant for discarding unwanted/unrecognized nodes in an XML document + */ + public fun skipCurrent() + /** * Peek at the next token type. Successive calls will return the same value, meaning there is only one * look-ahead at any given time during the parsing of input data. diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt index 1c2d648f6..35d28cd71 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt @@ -42,16 +42,18 @@ public class LexingXmlStreamReader(private val source: XmlLexer) : XmlStreamRead override fun skipNext() { val peekToken = peek(1) ?: return val startDepth = peekToken.depth + scanUntilDepth(startDepth, nextToken()) + } - tailrec fun scanUntilDepth(from: XmlToken?) { - when { - from == null || from is XmlToken.EndDocument -> return // End of document - from is XmlToken.EndElement && from.depth == startDepth -> return // Returned to original start depth - else -> scanUntilDepth(nextToken()) // Keep scannin'! - } + private tailrec fun scanUntilDepth(startDepth: Int, from: XmlToken?) { + when { + from == null || from is XmlToken.EndDocument -> return // End of document + from is XmlToken.EndElement && from.depth == startDepth -> return // Returned to original start depth + else -> scanUntilDepth(startDepth, nextToken()) // Keep scannin'! } - - scanUntilDepth(nextToken()) + } + override fun skipCurrent() { + scanUntilDepth(lastToken?.depth ?: 0, lastToken) } override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = @@ -110,6 +112,8 @@ private class ChildXmlStreamReader( override fun skipNext() = parent.skipNext() + override fun skipCurrent() = parent.skipCurrent() + override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = parent.subTreeReader(subtreeStartDepth) } @@ -118,17 +122,19 @@ private class ChildXmlStreamReader( * An empty XML stream reader that trivially returns `null` for all [nextToken] and [peek] invocations. * @param parent The [LexingXmlStreamReader] on which this child reader is based. */ -private class EmptyXmlStreamReader(private val parent: XmlStreamReader) : XmlStreamReader { +private class EmptyXmlStreamReader(private val parent: XmlStreamReader?) : XmlStreamReader { override val lastToken: XmlToken? - get() = parent.lastToken + get() = parent?.lastToken override fun nextToken(): XmlToken? = null override fun peek(index: Int): XmlToken? = null override fun skipNext() = Unit - + override fun skipCurrent() = Unit override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = this } private fun List.getOrNull(index: Int): T? = if (index < size) this[index] else null + +internal fun XmlStreamReader.emptyReader(parent: XmlStreamReader? = this): XmlStreamReader = EmptyXmlStreamReader(parent) diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt new file mode 100644 index 000000000..0ae02ea6d --- /dev/null +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt @@ -0,0 +1,147 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.serde.xml + +import kotlin.test.* + +class TagReaderTest { + + @Test + fun testNextTag() { + // inner b could be confused as closing the outer b if depth isn't tracked properly + val payload = """ + + + + + + + + + more + + """.encodeToByteArray() + val scoped = xmlStreamReader(payload).root() + val expected = listOf("a", "b", "c", "d") + .map { XmlToken.BeginElement(2, it) } + + expected.forEach { expectedStartTag -> + val tagReader = assertNotNull(scoped.nextTag()) + assertEquals(expectedStartTag, tagReader.startTag) + tagReader.drop() + } + } + + @Test + fun testNextTagScope() { + // test scope of each tag reader + val payload = """ + + + 1 + 2 + + + 3 + 4 + + + + abc + + + """.encodeToByteArray() + val scoped = xmlStreamReader(payload).root() + assertEquals(XmlToken.BeginElement(1, "Root"), scoped.startTag) + + val s1 = assertNotNull(scoped.nextTag()) + assertEquals(XmlToken.BeginElement(2, "Child1"), s1.startTag) + val s1Elements = listOf( + XmlToken.BeginElement(3, "x"), + XmlToken.Text(3, "1"), + XmlToken.EndElement(3, "x"), + XmlToken.BeginElement(3, "y"), + XmlToken.Text(3, "2"), + XmlToken.EndElement(3, "y"), + ) + assertEquals(s1Elements, s1.allTokens()) + + val s2 = assertNotNull(scoped.nextTag()) + assertEquals(XmlToken.BeginElement(2, "Child2"), s2.startTag) + + val aReader = assertNotNull(s2.nextTag()) + assertEquals(XmlToken.BeginElement(3, "a"), aReader.startTag) + assertNull(aReader.nextTag()) + + val bReader = assertNotNull(s2.nextTag()) + assertEquals(XmlToken.BeginElement(3, "b"), bReader.startTag) + assertEquals(XmlToken.Text(3, "4"), bReader.nextToken()) + assertNull(bReader.nextToken()) + bReader.drop() + + // self close token behavior + val selfCloseReader = assertNotNull(scoped.nextTag()) + assertEquals(emptyList(), selfCloseReader.allTokens()) + selfCloseReader.drop() + + val s4 = assertNotNull(scoped.nextTag()) + assertEquals(XmlToken.BeginElement(2, "Child4"), s4.startTag) + } + + @Test + fun testData() { + val payload = """ + + + 1 + 2 + + + this is an a + decoder should skip + + ignored a + ignored b + ignored c + + + + + + + + """.encodeToByteArray() + + val decoder = xmlStreamReader(payload).root() + loop@while (true) { + val curr = decoder.nextTag() ?: break@loop + when (curr.startTag.name.tag) { + "Child1" -> { + assertEquals(1, curr.nextTag()?.readInt()) + assertEquals(2, curr.nextTag()?.readInt()) + } + "Child2" -> { + assertEquals("this is an a", curr.nextTag()?.text()) + // intentionally ignore the next tag and don't consume the entire child subtree + } + "Child4" -> assertEquals(" ", curr.nextTag()?.text()) + else -> {} + } + // consume the current tag entirely before trying to process the next + curr.drop() + } + } +} + +fun TagReader.allTokens(): List { + val tokenList = mutableListOf() + var nextToken: XmlToken? + do { + nextToken = this.nextToken() + if (nextToken != null) tokenList.add(nextToken) + } while (nextToken != null) + + return tokenList +} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt index fc1ef5dca..b4226d6e9 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt @@ -193,8 +193,7 @@ class XmlStreamReaderTest { assertEquals(expected, actual) } - @Test - fun itSkipsValuesRecursively() { + private fun skipTest(skipCurrent: Boolean) { val payload = """ 1> @@ -225,16 +224,31 @@ class XmlStreamReaderTest { nextToken() // end x } - val nt = reader.peek() - assertIs(nt) + val nt = if (skipCurrent) { + reader.nextToken() + } else { + reader.peek() + } + assertIs(nt) assertEquals("unknown", nt.name.local) - reader.skipNext() + + if (skipCurrent) { + reader.skipCurrent() + } else { + reader.skipNext() + } val y = reader.nextToken() as XmlToken.BeginElement assertEquals("y", y.name.local) } + @Test + fun itSkipsNextValuesRecursively() = skipTest(false) + + @Test + fun itSkipsCurrentValuesRecursively() = skipTest(true) + @Test fun itSkipsSimpleValues() { val payload = """ From 4dc2106c61f8d7a03b340823385f7e46c0e62d50 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Mon, 19 Feb 2024 23:15:52 -0500 Subject: [PATCH 02/25] bootstrap generated serde tests --- build.gradle.kts | 3 +- settings.gradle.kts | 3 +- .../json/SerdeBenchmarkJsonProtocol.kt | 20 --- .../xml/SerdeBenchmarkXmlProtocol.kt | 20 --- ...tlin.codegen.integration.KotlinIntegration | 2 - .../serde-benchmarks/build.gradle.kts | 3 +- .../model/countriesstates.smithy | 4 +- .../model/serde-protocols.smithy | 13 -- .../serde-benchmarks/model/twitter.smithy | 4 +- .../serde-codegen-support}/build.gradle.kts | 2 +- .../codegen/protocols/ProtocolSupplier.kt | 15 ++ .../protocols/SerdeProtocolGenerator.kt} | 4 +- .../protocols/json/SerdeJsonProtocol.kt | 22 +++ .../json/SerdeJsonProtocolGenerator.kt} | 8 +- .../codegen/protocols/xml/SerdeXmlProtocol.kt | 23 +++ .../xml/SerdeXmlProtocolGenerator.kt} | 8 +- ...tlin.codegen.integration.KotlinIntegration | 1 + ...re.amazon.smithy.model.traits.TraitService | 2 + .../main/resources/META-INF/smithy/manifest | 1 + .../META-INF/smithy/protocols.smithy | 13 ++ tests/codegen/serde-tests/.gitignore | 1 + tests/codegen/serde-tests/build.gradle.kts | 88 ++++++++++ tests/codegen/serde-tests/model/shared.smithy | 154 ++++++++++++++++++ tests/codegen/serde-tests/model/xml.smithy | 18 ++ tests/codegen/serde-tests/smithy-build.json | 31 ++++ .../kotlin/tests/serde/XmlStructTest.kt | 44 +++++ 26 files changed, 434 insertions(+), 73 deletions(-) delete mode 100644 tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocol.kt delete mode 100644 tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocol.kt delete mode 100644 tests/benchmarks/serde-benchmarks-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration delete mode 100644 tests/benchmarks/serde-benchmarks/model/serde-protocols.smithy rename tests/{benchmarks/serde-benchmarks-codegen => codegen/serde-codegen-support}/build.gradle.kts (81%) create mode 100644 tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/ProtocolSupplier.kt rename tests/{benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/BenchmarkProtocolGenerator.kt => codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/SerdeProtocolGenerator.kt} (92%) create mode 100644 tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocol.kt rename tests/{benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocolGenerator.kt => codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocolGenerator.kt} (71%) create mode 100644 tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocol.kt rename tests/{benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocolGenerator.kt => codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt} (78%) create mode 100644 tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration create mode 100644 tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService create mode 100644 tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/manifest create mode 100644 tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/protocols.smithy create mode 100644 tests/codegen/serde-tests/.gitignore create mode 100644 tests/codegen/serde-tests/build.gradle.kts create mode 100644 tests/codegen/serde-tests/model/shared.smithy create mode 100644 tests/codegen/serde-tests/model/xml.smithy create mode 100644 tests/codegen/serde-tests/smithy-build.json create mode 100644 tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index cc54aef71..6e1e0c77a 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -130,7 +130,8 @@ apiValidation { "channel-benchmarks", "http-benchmarks", "serde-benchmarks", - "serde-benchmarks-codegen", + "serde-codegen-support", + "serde-tests", "nullability-tests", "paginator-tests", "waiter-tests", diff --git a/settings.gradle.kts b/settings.gradle.kts index 4608e6c65..e9e3d91d2 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -73,11 +73,12 @@ include(":tests") include(":tests:benchmarks:aws-signing-benchmarks") include(":tests:benchmarks:channel-benchmarks") include(":tests:benchmarks:http-benchmarks") -include(":tests:benchmarks:serde-benchmarks-codegen") include(":tests:benchmarks:serde-benchmarks") include(":tests:compile") include(":tests:codegen:nullability-tests") include(":tests:codegen:paginator-tests") +include(":tests:codegen:serde-tests") +include(":tests:codegen:serde-codegen-support") include(":tests:codegen:waiter-tests") include(":tests:integration:slf4j-1x-consumer") include(":tests:integration:slf4j-2x-consumer") diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocol.kt b/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocol.kt deleted file mode 100644 index 58dd3abb7..000000000 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocol.kt +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.kotlin.codegen.protocols.json - -import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration -import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator -import software.amazon.smithy.model.shapes.ShapeId - -/** - * Dummy protocol for use in serde-benchmark project models. Generates JSON based serializers/deserializers - */ -class SerdeBenchmarkJsonProtocol : KotlinIntegration { - companion object { - val ID: ShapeId = ShapeId.from("aws.benchmarks.protocols#serdeBenchmarkJson") - } - - override val protocolGenerators: List = listOf(SerdeBenchmarkJsonProtocolGenerator) -} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocol.kt b/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocol.kt deleted file mode 100644 index c2d940daf..000000000 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocol.kt +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.kotlin.codegen.protocols.xml - -import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration -import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator -import software.amazon.smithy.model.shapes.ShapeId - -/** - * Dummy protocol for use in serde-benchmark project models. Generates XML-based serializers/deserializers. - */ -class SerdeBenchmarkXmlProtocol : KotlinIntegration { - companion object { - val ID: ShapeId = ShapeId.from("aws.benchmarks.protocols#serdeBenchmarkXml") - } - - override val protocolGenerators: List = listOf(SerdeBenchmarkXmlProtocolGenerator) -} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration b/tests/benchmarks/serde-benchmarks-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration deleted file mode 100644 index 88d7b7750..000000000 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +++ /dev/null @@ -1,2 +0,0 @@ -software.amazon.smithy.kotlin.codegen.protocols.json.SerdeBenchmarkJsonProtocol -software.amazon.smithy.kotlin.codegen.protocols.xml.SerdeBenchmarkXmlProtocol diff --git a/tests/benchmarks/serde-benchmarks/build.gradle.kts b/tests/benchmarks/serde-benchmarks/build.gradle.kts index 61e89b8ea..a604a67ec 100644 --- a/tests/benchmarks/serde-benchmarks/build.gradle.kts +++ b/tests/benchmarks/serde-benchmarks/build.gradle.kts @@ -75,13 +75,14 @@ afterEvaluate { val codegen by configurations.getting dependencies { - codegen(project(":tests:benchmarks:serde-benchmarks-codegen")) + codegen(project(":tests:codegen:serde-codegen-support")) codegen(libs.smithy.cli) codegen(libs.smithy.model) } tasks.generateSmithyProjections { smithyBuildConfigs.set(files("smithy-build.json")) + buildClasspath.set(codegen) } data class BenchmarkModel(val name: String) { diff --git a/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy b/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy index 25dcd16d4..88e5c06e9 100644 --- a/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy +++ b/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy @@ -2,9 +2,9 @@ $version: "1.0" namespace aws.benchmarks.countries_states -use aws.benchmarks.protocols#serdeBenchmarkXml +use aws.serde.protocols#serdeXml -@serdeBenchmarkXml +@serdeXml service CountriesStatesService { version: "2019-12-16", operations: [GetCountriesAndStates] diff --git a/tests/benchmarks/serde-benchmarks/model/serde-protocols.smithy b/tests/benchmarks/serde-benchmarks/model/serde-protocols.smithy deleted file mode 100644 index b14ea9bf9..000000000 --- a/tests/benchmarks/serde-benchmarks/model/serde-protocols.smithy +++ /dev/null @@ -1,13 +0,0 @@ -$version: "1.0" - -namespace aws.benchmarks.protocols - -// dummy protocols just for benchmarking purposes - -@protocolDefinition -@trait -structure serdeBenchmarkJson{} - -@protocolDefinition -@trait -structure serdeBenchmarkXml{} diff --git a/tests/benchmarks/serde-benchmarks/model/twitter.smithy b/tests/benchmarks/serde-benchmarks/model/twitter.smithy index b11264123..d57f03115 100644 --- a/tests/benchmarks/serde-benchmarks/model/twitter.smithy +++ b/tests/benchmarks/serde-benchmarks/model/twitter.smithy @@ -2,9 +2,9 @@ $version: "1.0" namespace aws.benchmarks.twitter -use aws.benchmarks.protocols#serdeBenchmarkJson +use aws.serde.protocols#serdeJson -@serdeBenchmarkJson +@serdeJson service Twitter { version: "2019-12-16", operations: [GetFeed] diff --git a/tests/benchmarks/serde-benchmarks-codegen/build.gradle.kts b/tests/codegen/serde-codegen-support/build.gradle.kts similarity index 81% rename from tests/benchmarks/serde-benchmarks-codegen/build.gradle.kts rename to tests/codegen/serde-codegen-support/build.gradle.kts index 3a21d75ba..f5551afa4 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/build.gradle.kts +++ b/tests/codegen/serde-codegen-support/build.gradle.kts @@ -9,7 +9,7 @@ plugins { skipPublishing() -description = "Codegen support for serde-benchmarks project" +description = "Codegen support for serde related integration tests" dependencies { implementation(project(":codegen:smithy-kotlin-codegen")) diff --git a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/ProtocolSupplier.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/ProtocolSupplier.kt new file mode 100644 index 000000000..f08c06908 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/ProtocolSupplier.kt @@ -0,0 +1,15 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.protocols + +import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +import software.amazon.smithy.kotlin.codegen.protocols.json.SerdeJsonProtocolGenerator +import software.amazon.smithy.kotlin.codegen.protocols.xml.SerdeXmlProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator + +class ProtocolSupplier : KotlinIntegration { + override val protocolGenerators: List + get() = listOf(SerdeJsonProtocolGenerator, SerdeXmlProtocolGenerator) +} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/BenchmarkProtocolGenerator.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/SerdeProtocolGenerator.kt similarity index 92% rename from tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/BenchmarkProtocolGenerator.kt rename to tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/SerdeProtocolGenerator.kt index c8d3a1765..a1d85f729 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/BenchmarkProtocolGenerator.kt +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/SerdeProtocolGenerator.kt @@ -13,7 +13,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.traits.TimestampFormatTrait -abstract class BenchmarkProtocolGenerator : HttpBindingProtocolGenerator() { +abstract class SerdeProtocolGenerator : HttpBindingProtocolGenerator() { abstract val contentTypes: ProtocolContentTypes override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS @@ -38,7 +38,7 @@ abstract class BenchmarkProtocolGenerator : HttpBindingProtocolGenerator() { RuntimeTypes.Core.ExecutionContext, RuntimeTypes.Http.HttpCall, ) { - write("error(\"not needed for benchmark tests\")") + write("error(\"not needed for codegen related tests\")") } } } diff --git a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocol.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocol.kt new file mode 100644 index 000000000..8f91830d9 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocol.kt @@ -0,0 +1,22 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.protocols.json + +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.AnnotationTrait + +/** + * Dummy protocol for use in serde-benchmark project models. Generates JSON based serializers/deserializers + */ +class SerdeJsonProtocol : AnnotationTrait { + companion object { + val ID: ShapeId = ShapeId.from("aws.serde.protocols#serdeJson") + class Provider : AnnotationTrait.Provider(ID, ::SerdeJsonProtocol) + } + constructor(node: ObjectNode) : super(ID, node) + constructor() : this(Node.objectNode()) +} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocolGenerator.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocolGenerator.kt similarity index 71% rename from tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocolGenerator.kt rename to tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocolGenerator.kt index fe6ae04e4..30543899b 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocolGenerator.kt +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocolGenerator.kt @@ -4,17 +4,17 @@ */ package software.amazon.smithy.kotlin.codegen.protocols.json -import software.amazon.smithy.kotlin.codegen.protocols.BenchmarkProtocolGenerator +import software.amazon.smithy.kotlin.codegen.protocols.SerdeProtocolGenerator import software.amazon.smithy.kotlin.codegen.rendering.protocol.* import software.amazon.smithy.kotlin.codegen.rendering.serde.* import software.amazon.smithy.model.shapes.ShapeId /** - * Protocol generator for benchmark protocol [SerdeBenchmarkJsonProtocol] + * Protocol generator for benchmark protocol [SerdeJsonProtocol] */ -object SerdeBenchmarkJsonProtocolGenerator : BenchmarkProtocolGenerator() { +object SerdeJsonProtocolGenerator : SerdeProtocolGenerator() { override val contentTypes = ProtocolContentTypes.consistent("application/json") - override val protocol: ShapeId = SerdeBenchmarkJsonProtocol.ID + override val protocol: ShapeId = SerdeJsonProtocol.ID override fun structuredDataSerializer(ctx: ProtocolGenerator.GenerationContext): StructuredDataSerializerGenerator = JsonSerializerGenerator(this) diff --git a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocol.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocol.kt new file mode 100644 index 000000000..87e209dd6 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocol.kt @@ -0,0 +1,23 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.protocols.xml + +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.AnnotationTrait + +/** + * Dummy protocol for use in testing projects that need to test XML codegen. Generates XML-based serializers/deserializers. + */ +class SerdeXmlProtocol : AnnotationTrait { + companion object { + val ID: ShapeId = ShapeId.from("aws.serde.protocols#serdeXml") + class Provider : AnnotationTrait.Provider(ID, ::SerdeXmlProtocol) + } + + constructor(node: ObjectNode) : super(ID, node) + constructor() : this(Node.objectNode()) +} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocolGenerator.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt similarity index 78% rename from tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocolGenerator.kt rename to tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt index 9b421b6eb..2b0d14461 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocolGenerator.kt +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt @@ -4,7 +4,7 @@ */ package software.amazon.smithy.kotlin.codegen.protocols.xml -import software.amazon.smithy.kotlin.codegen.protocols.BenchmarkProtocolGenerator +import software.amazon.smithy.kotlin.codegen.protocols.SerdeProtocolGenerator import software.amazon.smithy.kotlin.codegen.rendering.protocol.* import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataParserGenerator import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataSerializerGenerator @@ -13,11 +13,11 @@ import software.amazon.smithy.kotlin.codegen.rendering.serde.XmlSerializerGenera import software.amazon.smithy.model.shapes.ShapeId /** - * Protocol generator for benchmark protocol [SerdeBenchmarkXmlProtocol]. + * Protocol generator for testing [SerdeXmlProtocol]. */ -object SerdeBenchmarkXmlProtocolGenerator : BenchmarkProtocolGenerator() { +object SerdeXmlProtocolGenerator : SerdeProtocolGenerator() { override val contentTypes = ProtocolContentTypes.consistent("application/xml") - override val protocol: ShapeId = SerdeBenchmarkXmlProtocol.ID + override val protocol: ShapeId = SerdeXmlProtocol.ID override fun structuredDataParser(ctx: ProtocolGenerator.GenerationContext): StructuredDataParserGenerator = XmlParserGenerator(this, defaultTimestampFormat) diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration new file mode 100644 index 000000000..8d5752fb7 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration @@ -0,0 +1 @@ +software.amazon.smithy.kotlin.codegen.protocols.ProtocolSupplier diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService new file mode 100644 index 000000000..895c9d9c8 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService @@ -0,0 +1,2 @@ +software.amazon.smithy.kotlin.codegen.protocols.json.SerdeJsonProtocol$Companion$Provider +software.amazon.smithy.kotlin.codegen.protocols.xml.SerdeXmlProtocol$Companion$Provider \ No newline at end of file diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/manifest b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/manifest new file mode 100644 index 000000000..31b96587d --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/manifest @@ -0,0 +1 @@ +protocols.smithy \ No newline at end of file diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/protocols.smithy b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/protocols.smithy new file mode 100644 index 000000000..2f4413170 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/protocols.smithy @@ -0,0 +1,13 @@ +$version: "2.0" + +namespace aws.serde.protocols + +// dummy protocols just for testing/benchmarking purposes + +@protocolDefinition +@trait +structure serdeJson{} + +@protocolDefinition +@trait +structure serdeXml{} diff --git a/tests/codegen/serde-tests/.gitignore b/tests/codegen/serde-tests/.gitignore new file mode 100644 index 000000000..fc706388c --- /dev/null +++ b/tests/codegen/serde-tests/.gitignore @@ -0,0 +1 @@ +generated-src \ No newline at end of file diff --git a/tests/codegen/serde-tests/build.gradle.kts b/tests/codegen/serde-tests/build.gradle.kts new file mode 100644 index 000000000..1c70b1f12 --- /dev/null +++ b/tests/codegen/serde-tests/build.gradle.kts @@ -0,0 +1,88 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import aws.sdk.kotlin.gradle.codegen.smithyKotlinProjectionSrcDir +import aws.sdk.kotlin.gradle.dsl.skipPublishing + +plugins { + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.aws.kotlin.repo.tools.smithybuild) +} + +skipPublishing() + +val codegen by configurations.getting +dependencies { + codegen(project(":codegen:smithy-kotlin-codegen")) + codegen(project(":tests:codegen:serde-codegen-support")) + codegen(libs.smithy.cli) + codegen(libs.smithy.model) +} + +tasks.generateSmithyProjections { + smithyBuildConfigs.set(files("smithy-build.json")) + inputs.dir(project.layout.projectDirectory.dir("model")) + buildClasspath.set(codegen) +} + +val optinAnnotations = listOf("kotlin.RequiresOptIn", "aws.smithy.kotlin.runtime.InternalApi") +kotlin.sourceSets.all { + optinAnnotations.forEach { languageSettings.optIn(it) } +} + +tasks.test { + useJUnitPlatform() + testLogging { + events("passed", "skipped", "failed") + showStandardStreams = true + } +} + +dependencies { + compileOnly(project(":codegen:smithy-kotlin-codegen")) + + implementation(libs.kotlinx.coroutines.core) + implementation(project(":runtime:runtime-core")) + implementation(project(":runtime:serde")) + implementation(project(":runtime:serde:serde-json")) + implementation(project(":runtime:serde:serde-xml")) + + testImplementation(libs.kotlin.test.junit5) +} + +val generatedSrcDir = project.layout.projectDirectory.dir("generated-src/main/kotlin") + +val stageGeneratedSources = tasks.register("stageGeneratedSources") { + group = "codegen" + dependsOn(tasks.generateSmithyProjections) + outputs.dir(generatedSrcDir) + doLast { + listOf("xml", "json").forEach { projectionName -> + val fromDir = smithyBuild.smithyKotlinProjectionSrcDir(projectionName) + logger.info("copying from ${fromDir.get()} to $generatedSrcDir") + copy { + from(fromDir) + into(generatedSrcDir) + include("**/model/*.kt") + include("**/serde/*.kt") + exclude("**/auth/*.kt") + exclude("**/endpoints/**.kt") + exclude("**/serde/*OperationSerializer.kt") + exclude("**/serde/*OperationDeserializer.kt") + } + } + } +} + +kotlin.sourceSets.getByName("main") { + kotlin.srcDir(generatedSrcDir) +} + +tasks.withType { + dependsOn(stageGeneratedSources) +} + +tasks.clean.configure { + delete(project.layout.projectDirectory.dir("generated-src")) +} diff --git a/tests/codegen/serde-tests/model/shared.smithy b/tests/codegen/serde-tests/model/shared.smithy new file mode 100644 index 000000000..fb96f483c --- /dev/null +++ b/tests/codegen/serde-tests/model/shared.smithy @@ -0,0 +1,154 @@ +$version: "2.0" + +namespace aws.tests.serde.shared + +list StringList { + member: String, +} + +@sparse +list SparseStringList { + member: String +} + +map StringMap { + key: String, + value: String, +} + +map StringListMap { + key: String, + value: StringList +} + +map NestedStringMap { + key: String, + value: StringMap +} + +@sparse +map SparseStringMap { + key: String, + value: String, +} + +list NestedStringList { + member: StringList, +} + +list IntegerList { + member: Integer, +} + +@uniqueItems +list IntegerSet { + member: Integer, +} + +enum FooEnum { + FOO = "Foo" + BAZ = "Baz" + BAR = "Bar" + ONE = "1" + ZERO = "0" +} + +list FooEnumList { + member: FooEnum, +} + +map FooEnumMap { + key: String, + value: FooEnum, +} + +@timestampFormat("date-time") +timestamp DateTime + +@timestampFormat("epoch-seconds") +timestamp EpochSeconds + +@timestampFormat("http-date") +timestamp HttpDate + +intEnum IntegerEnum { + A = 1 + B = 2 + C = 3 +} + +list IntegerEnumList { + member: IntegerEnum +} + +map IntegerEnumMap { + key: String, + value: IntegerEnum +} + + +union Choice { + @xmlFlattened + @xmlName("flatmap") + flatMap: StringMap, + + normalMap: StringMap, + + sparseMap: SparseStringMap, + + // FIXME - doesn't work with current codegen + // listMap: StringListMap, + + // FIXME - doesn't work with current codegen + // nestedMap: NestedStringMap + + @xmlFlattened + @xmlName("flatlist") + flatList: StringList, + + normalList: StringList, + + sparseList: SparseStringList, + + // FIXME - doesn't work with current codegen + // nestedList: NestedStringList, + + str: String, + + enum: FooEnum, + + dateTime: DateTime, + epochTime: EpochSeconds, + httpTime: HttpDate, + + @xmlName("double") + fpDouble: Double, + + top: Top, + + blob: Blob, + + unit: Unit, + + // TODO - enum lists, timestamp lists, structure list, structure map, multiple flat lists interspersed (xml only) +} + +structure Top { + choice: Choice, + + strField: String, + + enumField: FooEnum, + + @xmlAttribute + extra: Long, + + @xmlName("prefix:local") + renamedWithPrefix: String, + + + // FIXME - move back to Choice when supported properly + listMap: StringListMap, + nestedMap: NestedStringMap + nestedList: NestedStringList, +} \ No newline at end of file diff --git a/tests/codegen/serde-tests/model/xml.smithy b/tests/codegen/serde-tests/model/xml.smithy new file mode 100644 index 000000000..ca425453e --- /dev/null +++ b/tests/codegen/serde-tests/model/xml.smithy @@ -0,0 +1,18 @@ +$version: "1.0" + +namespace aws.tests.serde.xml + +use aws.serde.protocols#serdeXml +use aws.tests.serde.shared#Top + +@serdeXml +service XmlService { + version: "2022-07-07", + operations: [TestOp] +} + +@http(uri: "/top", method: "POST") +operation TestOp { + input: Top, + output: Top, +} diff --git a/tests/codegen/serde-tests/smithy-build.json b/tests/codegen/serde-tests/smithy-build.json new file mode 100644 index 000000000..1ea49d608 --- /dev/null +++ b/tests/codegen/serde-tests/smithy-build.json @@ -0,0 +1,31 @@ +{ + "version": "1.0", + "sources": ["model"], + "projections": { + "xml": { + "transforms": [ + { + "name": "includeServices", + "args": { + "services": [ + "aws.tests.serde.xml#XmlService" + ] + } + } + ], + "plugins": { + "kotlin-codegen": { + "service": "aws.tests.serde.xml#XmlService", + "package": { + "name": "aws.smithy.kotlin.tests.serde.xml", + "version": "0.0.1" + }, + "build": { + "rootProject": false, + "generateDefaultBuildFiles": false + } + } + } + } + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt new file mode 100644 index 000000000..ed80a5062 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt @@ -0,0 +1,44 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer +import aws.smithy.kotlin.runtime.serde.xml.XmlSerializer +import aws.smithy.kotlin.tests.serde.xml.model.FooEnum +import aws.smithy.kotlin.tests.serde.xml.model.Top +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeTopDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeTopDocument +import kotlin.test.Test +import kotlin.test.assertEquals + +class XmlStructTest { + @Test + fun testStructPrimitives() { + val expected = Top { + strField = "a string" + enumField = FooEnum.Bar + extra = 42 + } + + val payload = """ + + a string + Bar + + """.trimIndent().encodeToByteArray() + + val serializer = XmlSerializer() + serializeTopDocument(serializer, expected) + val actualPayload = serializer.toByteArray().decodeToString() + + val deserializer = XmlDeserializer(payload) + val actualDeserialized = deserializeTopDocument(deserializer) + assertEquals(expected, actualDeserialized) + + // TODO - use assertXmlStringsEqual from smithy-test + // TODO - figure out roundtrip structure + // TODO - turn into abstract base for XML vs JSON + } +} From f30302fa71f284786f19cbc650886acff1092b46 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Tue, 20 Feb 2024 13:42:33 -0500 Subject: [PATCH 03/25] add xml serde test suite --- .../runtime/smithy/test/XmlAssertions.kt | 2 +- tests/codegen/serde-tests/build.gradle.kts | 1 + tests/codegen/serde-tests/model/shared.smithy | 113 +++++------ tests/codegen/serde-tests/model/xml.smithy | 78 +++++++- .../kotlin/tests/serde/AbstractXmlTest.kt | 28 +++ .../smithy/kotlin/tests/serde/XmlListTest.kt | 123 ++++++++++++ .../smithy/kotlin/tests/serde/XmlMapTest.kt | 165 +++++++++++++++++ .../kotlin/tests/serde/XmlStructTest.kt | 110 +++++++++-- .../smithy/kotlin/tests/serde/XmlUnionTest.kt | 175 ++++++++++++++++++ 9 files changed, 718 insertions(+), 77 deletions(-) create mode 100644 tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt create mode 100644 tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt create mode 100644 tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt create mode 100644 tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt diff --git a/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt b/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt index f0c964a90..5d75333bb 100644 --- a/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt +++ b/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt @@ -14,7 +14,7 @@ import kotlin.test.assertEquals /** * Assert XML strings for equality ignoring key order */ -public suspend fun assertXmlStringsEqual(expected: String, actual: String) { +public fun assertXmlStringsEqual(expected: String, actual: String) { // parse into a dom representation and sort the dom into a canonical form for comparison val expectedNode = XmlNode.parse(expected.encodeToByteArray()).apply { toCanonicalForm() } val actualNode = XmlNode.parse(actual.encodeToByteArray()).apply { toCanonicalForm() } diff --git a/tests/codegen/serde-tests/build.gradle.kts b/tests/codegen/serde-tests/build.gradle.kts index 1c70b1f12..36e94a306 100644 --- a/tests/codegen/serde-tests/build.gradle.kts +++ b/tests/codegen/serde-tests/build.gradle.kts @@ -47,6 +47,7 @@ dependencies { implementation(project(":runtime:serde")) implementation(project(":runtime:serde:serde-json")) implementation(project(":runtime:serde:serde-xml")) + implementation(project(":runtime:smithy-test")) testImplementation(libs.kotlin.test.junit5) } diff --git a/tests/codegen/serde-tests/model/shared.smithy b/tests/codegen/serde-tests/model/shared.smithy index fb96f483c..87abe6612 100644 --- a/tests/codegen/serde-tests/model/shared.smithy +++ b/tests/codegen/serde-tests/model/shared.smithy @@ -87,68 +87,77 @@ map IntegerEnumMap { } -union Choice { - @xmlFlattened - @xmlName("flatmap") - flatMap: StringMap, +@mixin +structure PrimitiveTypesMixin { + strField: String, + byteField: Byte, + intField: Integer, + shortField: Short, + longField: Long, + floatField: Float, + doubleField: Double, + bigIntegerField: BigInteger, + bigDecimalField: BigDecimal, + boolField: Boolean, + blobField: Blob, + enumField: FooEnum, + intEnumField: IntegerEnum, + dateTimeField: DateTime, + epochTimeField: EpochSeconds, + httpTimeField: HttpDate, +} - normalMap: StringMap, +@mixin +union PrimitiveTypesUnionMixin { + strField: String, + byteField: Byte, + intField: Integer, + shortField: Short, + longField: Long, + floatField: Float, + doubleField: Double, + bigIntegerField: BigInteger, + bigDecimalField: BigDecimal, + boolField: Boolean, + blobField: Blob, + enumField: FooEnum, + intEnumField: IntegerEnum, + dateTimeField: DateTime, + epochTimeField: EpochSeconds, + httpTimeField: HttpDate, + unitField: Unit +} +@mixin +structure MapTypesMixin { + normalMap: StringMap, sparseMap: SparseStringMap, + nestedMap: NestedStringMap, + listMap: StringListMap, +} - // FIXME - doesn't work with current codegen - // listMap: StringListMap, - - // FIXME - doesn't work with current codegen - // nestedMap: NestedStringMap +@mixin +union MapTypesUnionMixin { + normalMap: StringMap, + sparseMap: SparseStringMap, + // FIXME - doesn't work with current codegen for unions + // nestedMap: NestedStringMap, +} - @xmlFlattened - @xmlName("flatlist") - flatList: StringList, +@mixin +structure ListTypesMixin { + normalList: StringList, + sparseList: SparseStringList, + nestedList: NestedStringList, +} +@mixin +union ListTypesUnionMixin { normalList: StringList, sparseList: SparseStringList, - // FIXME - doesn't work with current codegen + // FIXME - doesn't work with current codegen for unions // nestedList: NestedStringList, - - str: String, - - enum: FooEnum, - - dateTime: DateTime, - epochTime: EpochSeconds, - httpTime: HttpDate, - - @xmlName("double") - fpDouble: Double, - - top: Top, - - blob: Blob, - - unit: Unit, - - // TODO - enum lists, timestamp lists, structure list, structure map, multiple flat lists interspersed (xml only) } -structure Top { - choice: Choice, - - strField: String, - - enumField: FooEnum, - - @xmlAttribute - extra: Long, - - @xmlName("prefix:local") - renamedWithPrefix: String, - - - // FIXME - move back to Choice when supported properly - listMap: StringListMap, - nestedMap: NestedStringMap - nestedList: NestedStringList, -} \ No newline at end of file diff --git a/tests/codegen/serde-tests/model/xml.smithy b/tests/codegen/serde-tests/model/xml.smithy index ca425453e..a360f6958 100644 --- a/tests/codegen/serde-tests/model/xml.smithy +++ b/tests/codegen/serde-tests/model/xml.smithy @@ -1,9 +1,21 @@ -$version: "1.0" +$version: "2.0" namespace aws.tests.serde.xml use aws.serde.protocols#serdeXml -use aws.tests.serde.shared#Top +use aws.tests.serde.shared#PrimitiveTypesMixin +use aws.tests.serde.shared#ListTypesMixin +use aws.tests.serde.shared#MapTypesMixin +use aws.tests.serde.shared#PrimitiveTypesUnionMixin +use aws.tests.serde.shared#ListTypesUnionMixin +use aws.tests.serde.shared#MapTypesUnionMixin +use aws.tests.serde.shared#StringMap +use aws.tests.serde.shared#StringListMap +use aws.tests.serde.shared#NestedStringMap +use aws.tests.serde.shared#FooEnumMap +use aws.tests.serde.shared#IntegerList +use aws.tests.serde.shared#StringList +use aws.tests.serde.shared#NestedStringList @serdeXml service XmlService { @@ -13,6 +25,64 @@ service XmlService { @http(uri: "/top", method: "POST") operation TestOp { - input: Top, - output: Top, + input: StructType, + output: StructType, +} + +structure StructType with [PrimitiveTypesMixin, ListTypesMixin, MapTypesMixin] { + unionField: UnionType, + + recursive: StructType, + + @xmlAttribute + extra: Long, + + @xmlName("prefix:local") + renamedWithPrefix: String, + + @xmlFlattened + @xmlName("flatlist1") + flatList: StringList, + + @xmlFlattened + @xmlName("flatlist2") + secondFlatList: IntegerList + + @xmlFlattened + @xmlName("flatenummap") + flatEnumMap: FooEnumMap, + + renamedMemberList: RenamedMemberIntList + + renamedMemberMap: RenamedMap +} + +list RenamedMemberIntList { + @xmlName("item") + member: String +} + +map RenamedMap { + @xmlName("aKey") + key: String + + @xmlName("aValue") + value: String +} + +union UnionType with [PrimitiveTypesUnionMixin, ListTypesUnionMixin, MapTypesUnionMixin] { + @xmlFlattened + @xmlName("flatmap") + flatMap: StringMap, + + @xmlFlattened + @xmlName("flatlist") + flatList: StringList, + + @xmlName("double") + fpDouble: Double, + + struct: StructType, + + // TODO - enum lists, timestamp lists, structure list, structure map, multiple flat lists interspersed (xml only) } diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt new file mode 100644 index 000000000..2b8ed1792 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt @@ -0,0 +1,28 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer +import aws.smithy.kotlin.runtime.serde.xml.XmlSerializer +import aws.smithy.kotlin.runtime.smithy.test.assertXmlStringsEqual +import kotlin.test.assertEquals + +abstract class AbstractXmlTest { + fun testRoundTrip( + expected: T, + payload: String, + serializerFn: (XmlSerializer, T) -> Unit, + deserializerFn: (XmlDeserializer) -> T, + ) { + val serializer = XmlSerializer() + serializerFn(serializer, expected) + val actualPayload = serializer.toByteArray().decodeToString() + assertXmlStringsEqual(payload, actualPayload) + + val deserializer = XmlDeserializer(payload.encodeToByteArray()) + val actualDeserialized = deserializerFn(deserializer) + assertEquals(expected, actualDeserialized) + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt new file mode 100644 index 000000000..8f2fb8573 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt @@ -0,0 +1,123 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import kotlin.test.Test + +class XmlListTest : AbstractXmlTest() { + @Test + fun testNormalList() { + val expected = StructType { + normalList = listOf("bar", "baz") + } + val payload = """ + + + bar + baz + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testSparseList() { + val expected = StructType { + sparseList = listOf("bar", null, "baz") + } + val payload = """ + + + bar + + baz + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNestedList() { + val expected = StructType { + nestedList = listOf( + listOf("a", "b", "c"), + listOf("x", "y", "z"), + ) + } + val payload = """ + + + + a + b + c + + + x + y + z + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testListWithRenamedMember() { + val expected = StructType { + renamedMemberList = listOf("bar", "baz") + } + val payload = """ + + + bar + baz + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testFlatList() { + val expected = StructType { + flatList = listOf("foo", "bar") + } + val payload = """ + + foo + bar + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + // FIXME - re-enable after we implement fix + // @Test + // fun testDeserializeInterspersedSparseLists() { + // // see https://github.com/awslabs/aws-sdk-kotlin/issues/1220 + // val expected = StructType { + // flatList = listOf("foo", "bar") + // secondFlatList = listOf(1, 2) + // } + // val payload = """ + // + // foo + // 1 + // bar + // 2 + // + // """.trimIndent() + // val deserializer = XmlDeserializer(payload.encodeToByteArray()) + // val actualDeserialized = deserializeStructTypeDocument(deserializer) + // assertEquals(expected, actualDeserialized) + // } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt new file mode 100644 index 000000000..7214b9db8 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt @@ -0,0 +1,165 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.tests.serde.xml.model.FooEnum +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import kotlin.test.Test + +class XmlMapTest : AbstractXmlTest() { + @Test + fun testNormalMap() { + val expected = StructType { + normalMap = mapOf( + "foo" to "bar", + "baz" to "quux", + ) + } + val payload = """ + + + + foo + bar + + + baz + quux + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testSparseMap() { + val expected = StructType { + sparseMap = mapOf( + "foo" to "bar", + "null" to null, + "baz" to "quux", + ) + } + val payload = """ + + + + foo + bar + + + null + + + + baz + quux + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNestedMap() { + val expected = StructType { + nestedMap = mapOf( + "foo" to mapOf( + "k1" to "v1", + "k2" to "v2", + ), + "bar" to mapOf( + "k3" to "v3", + "k4" to "v4", + ), + ) + } + val payload = """ + + + + foo + + + k1 + v1 + + + k2 + v2 + + + + + bar + + + k3 + v3 + + + k4 + v4 + + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testMapWithRenamedMember() { + val expected = StructType { + renamedMemberMap = mapOf( + "foo" to "bar", + "baz" to "quux", + ) + } + val payload = """ + + + + foo + bar + + + baz + quux + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testFlatMap() { + val expected = StructType { + flatEnumMap = mapOf( + "foo" to FooEnum.Foo, + "bar" to FooEnum.Bar, + ) + } + val payload = """ + + + foo + Foo + + + bar + Bar + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt index ed80a5062..8fa03295c 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt @@ -4,41 +4,111 @@ */ package aws.smithy.kotlin.tests.serde -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer -import aws.smithy.kotlin.runtime.serde.xml.XmlSerializer +import aws.smithy.kotlin.runtime.content.BigInteger +import aws.smithy.kotlin.runtime.text.encoding.encodeBase64String +import aws.smithy.kotlin.runtime.time.Instant import aws.smithy.kotlin.tests.serde.xml.model.FooEnum -import aws.smithy.kotlin.tests.serde.xml.model.Top -import aws.smithy.kotlin.tests.serde.xml.serde.deserializeTopDocument -import aws.smithy.kotlin.tests.serde.xml.serde.serializeTopDocument +import aws.smithy.kotlin.tests.serde.xml.model.IntegerEnum +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import java.math.BigDecimal import kotlin.test.Test -import kotlin.test.assertEquals -class XmlStructTest { +class XmlStructTest : AbstractXmlTest() { @Test fun testStructPrimitives() { - val expected = Top { + val expected = StructType { strField = "a string" + byteField = 2.toByte() + intField = 3 + shortField = 4 + longField = 5L + floatField = 6.0f + doubleField = 7.1 + bigIntegerField = BigInteger("1234") + bigDecimalField = BigDecimal("1.234") + boolField = true + blobField = "blob field".encodeToByteArray() enumField = FooEnum.Bar + intEnumField = IntegerEnum.C + dateTimeField = Instant.fromIso8601("2020-10-16T15:46:24.982Z") + epochTimeField = Instant.fromEpochSeconds(1657204347) + httpTimeField = Instant.fromRfc5322("Sat, 22 Jul 2017 19:30:00 GMT") extra = 42 } + val base64BlobField = expected.blobField!!.encodeBase64String() + val payload = """ - + a string + 2 + 3 + 4 + 5 + 6.0 + 7.1 + 1234 + 1.234 + true + $base64BlobField Bar - - """.trimIndent().encodeToByteArray() + 3 + 2020-10-16T15:46:24.982Z + 1657204347 + Sat, 22 Jul 2017 19:30:00 GMT + + """.trimIndent() + + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } - val serializer = XmlSerializer() - serializeTopDocument(serializer, expected) - val actualPayload = serializer.toByteArray().decodeToString() + @Test + fun testRenamedMembers() { + val expected = StructType { + renamedWithPrefix = "foo" + flatList = listOf("bar", "baz") + } + val payload = """ + + foo + bar + baz + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } - val deserializer = XmlDeserializer(payload) - val actualDeserialized = deserializeTopDocument(deserializer) - assertEquals(expected, actualDeserialized) + @Test + fun testRecursiveType() { + val expected = StructType { + strField = "first" + recursive { + strField = "second" + extra = 42 + recursive { + strField = "third" + normalList = listOf("foo", "bar") + } + } + } + val payload = """ + + first + + second + + third + + foo + bar + + + + + """.trimIndent() - // TODO - use assertXmlStringsEqual from smithy-test - // TODO - figure out roundtrip structure - // TODO - turn into abstract base for XML vs JSON + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) } } diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt new file mode 100644 index 000000000..5e3d6502a --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt @@ -0,0 +1,175 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.model.UnionType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import kotlin.test.Test + +class XmlUnionTest : AbstractXmlTest() { + @Test + fun testString() { + val expected = StructType { + unionField = UnionType.StrField("a string") + } + val payload = """ + + + a string + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testByte() { + val expected = StructType { + unionField = UnionType.ByteField(1) + } + val payload = """ + + + 1 + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testInt() { + val expected = StructType { + unionField = UnionType.IntField(1) + } + val payload = """ + + + 1 + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testLong() { + val expected = StructType { + unionField = UnionType.LongField(1) + } + val payload = """ + + + 1 + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testTimestamp() { + val expected = StructType { + unionField = UnionType.DateTimeField( + Instant.fromIso8601("2020-10-16T15:46:24.982Z"), + ) + } + val payload = """ + + + 2020-10-16T15:46:24.982Z + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNormalList() { + val expected = StructType { + unionField = UnionType.NormalList(listOf("foo", "bar")) + } + + val payload = """ + + + + foo + bar + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNormalMap() { + val expected = StructType { + unionField = UnionType.NormalMap( + mapOf( + "k1" to "v1", + "k2" to "v2", + ), + ) + } + val payload = """ + + + + + k1 + v1 + + + k2 + v2 + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + // FIXME - https://github.com/awslabs/smithy-kotlin/issues/1040 + // @Test + // fun testUnitField() { } + + @Test + fun testStruct() { + val expected = StructType { + unionField = UnionType.Struct( + StructType { + normalMap = mapOf("k1" to "v1", "k2" to "v2") + strField = "a string" + }, + ) + } + val payload = """ + + + + + + k1 + v1 + + + k2 + v2 + + + a string + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } +} From 0f3fe6938c149b312a615404de44bab523689b74 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Thu, 22 Feb 2024 14:54:02 -0500 Subject: [PATCH 04/25] implement map and list deserialize --- .../codegen/core/AbstractCodeWriterExt.kt | 4 + .../kotlin/codegen/core/RuntimeTypes.kt | 13 + .../serde/DeserializeStructGenerator.kt | 2 - .../codegen/rendering/serde/SerdeExt.kt | 2 + .../rendering/serde/XmlParserGenerator.kt | 342 +++++++++++++++++- gradle/libs.versions.toml | 2 +- .../runtime/collections/CollectionExt.kt | 11 +- .../kotlin/runtime/serde/xml/TagReader.kt | 14 +- tests/codegen/serde-tests/build.gradle.kts | 7 + .../kotlin/tests/serde/AbstractXmlTest.kt | 27 +- 10 files changed, 385 insertions(+), 39 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt index 54bded579..2b943da04 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt @@ -187,3 +187,7 @@ fun > T.callIf(test: Boolean, runnable: Runnable): T { } return this } + +/** Escape the [expressionStart] character to avoid problems during formatting */ +fun > T.escape(text: String): String = + text.replace("$expressionStart", "$expressionStart$expressionStart") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index 36bc0cec8..9267e3351 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -103,6 +103,7 @@ object RuntimeTypes { val Attributes = symbol("Attributes") val attributesOf = symbol("attributesOf") val AttributeKey = symbol("AttributeKey") + val createOrAppend = symbol("createOrAppend") val get = symbol("get") val mutableMultiMapOf = symbol("mutableMultiMapOf") val putIfAbsent = symbol("putIfAbsent") @@ -262,6 +263,18 @@ object RuntimeTypes { val XmlSerializer = symbol("XmlSerializer") val XmlDeserializer = symbol("XmlDeserializer") val XmlUnwrappedOutput = symbol("XmlUnwrappedOutput") + + val TagReader = symbol("TagReader") + val xmlStreamReader = symbol("xmlStreamReader") + val root = symbol("root") + val text = symbol("text") + val readInt = symbol("readInt") + val readShort = symbol("readShort") + val readLong = symbol("readLong") + val readFloat = symbol("readFloat") + val readDouble = symbol("readDouble") + val readByte = symbol("readByte") + val readBoolean = symbol("readBoolean") } object SerdeFormUrl : RuntimeTypePackage(KotlinDependency.SERDE_FORM_URL) { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt index 5680af9f5..2f8e8de51 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt @@ -605,5 +605,3 @@ open class DeserializeStructGenerator( } } } - -private fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else "" diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt index 56ab30cd5..7f53e0755 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt @@ -289,3 +289,5 @@ internal fun Shape.childShape(model: Model): Shape? = when (this) { is MapShape -> model.expectShape(this.value.target) else -> null } + +internal fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else "" diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index da3d78448..cfddf16a5 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -5,19 +5,21 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.kotlin.codegen.core.KotlinWriter -import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes -import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.codegen.core.SymbolReference +import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes.Serde.SerdeXml +import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex -import software.amazon.smithy.kotlin.codegen.model.targetOrSelf import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext -import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.* +import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.model.traits.XmlFlattenedTrait +import software.amazon.smithy.model.traits.XmlNameTrait +import software.amazon.smithy.utils.StringUtils /** * XML parser generator based on common deserializer interface and XML serde descriptors @@ -28,6 +30,7 @@ open class XmlParserGenerator( private val defaultTimestampFormat: TimestampFormatTrait.Format, ) : StructuredDataParserGenerator { + // FIXME - remove open fun descriptorGenerator( ctx: ProtocolGenerator.GenerationContext, shape: Shape, @@ -73,7 +76,7 @@ open class XmlParserGenerator( documentMembers: List, writer: KotlinWriter, ) { - writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer) + writer.write("val reader = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) val shape = ctx.model.expectShape(op.output.get()) renderDeserializerBody(ctx, shape, documentMembers, writer) } @@ -84,23 +87,23 @@ open class XmlParserGenerator( members: List, writer: KotlinWriter, ) { - descriptorGenerator(ctx, shape, members, writer).render() if (shape.isUnionShape) { - val name = ctx.symbolProvider.toSymbol(shape).name - DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render() + // TODO - parse unions + // val name = ctx.symbolProvider.toSymbol(shape).name + // DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render() } else { - DeserializeStructGenerator(ctx, members, writer, defaultTimestampFormat).render() + deserializeStruct(ctx, shape, members, writer) } } - protected fun documentDeserializer( + private fun documentDeserializer( ctx: ProtocolGenerator.GenerationContext, shape: Shape, members: Collection = shape.members(), ): Symbol { val symbol = ctx.symbolProvider.toSymbol(shape) return shape.documentDeserializer(ctx.settings, symbol, members) { writer -> - writer.openBlock("internal fun #identifier.name:L(deserializer: #T): #T {", RuntimeTypes.Serde.Deserializer, symbol) + writer.openBlock("internal fun #identifier.name:L(reader: #T): #T {", SerdeXml.TagReader, symbol) .call { if (shape.isUnionShape) { writer.write("var value: #T? = null", symbol) @@ -128,7 +131,7 @@ open class XmlParserGenerator( val fnName = symbol.errorDeserializerName() writer.openBlock("internal fun #L(builder: #T.Builder, payload: ByteArray) {", fnName, symbol) .call { - writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer) + writer.write("val reader = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) renderDeserializerBody(ctx, errorShape, members, writer) } .closeBlock("}") @@ -152,10 +155,315 @@ open class XmlParserGenerator( // short circuit when the shape has no modeled members to deserialize write("return #T.Builder().build()", symbol) } else { - write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer) + write("val deserializer = #T(payload)", SerdeXml.XmlDeserializer) write("return #T(deserializer)", deserializeFn) } } } } + + private fun KotlinWriter.deserializeLoop( + ignoreUnexpected: Boolean = true, + block: KotlinWriter.() -> Unit, + ) { + withBlock("loop@while(true) {", "}") { + write("val curr = reader.nextTag() ?: break@loop") + withBlock("when(curr.startTag.name.tag) {", "}") { + block(this) + if (ignoreUnexpected) { + write("else -> {}") + } + } + write("curr.drop()") + } + } + private fun deserializeStruct( + ctx: ProtocolGenerator.GenerationContext, + shape: Shape, + members: List, + writer: KotlinWriter, + ) { + // TODO - split attribute members and non attribute members + // TODO - don't generate a parse loop if no attribute members + writer.deserializeLoop { + members.forEach { member -> + val name = member.getTrait()?.value ?: member.memberName + write("// ${member.memberName} ${escape(member.id.toString())}") + writeInline("#S -> builder.#L = ", name, member.defaultName()) + deserializeMember(ctx, member, writer) + } + } + } + + private fun deserializeMember( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + when (target.type) { + ShapeType.LIST, ShapeType.SET -> { + if (member.hasTrait()) { + deserializeFlatList(ctx, member, writer) + } else { + deserializeList(ctx, member, writer) + } + } + ShapeType.MAP -> { + if (member.hasTrait()) { + deserializeFlatMap(ctx, member, writer) + } else { + deserializeMap(ctx, member, writer) + } + } + ShapeType.STRUCTURE, ShapeType.UNION -> { + val deserializeFn = documentDeserializer(ctx, target) + writer.write("#T(curr)", deserializeFn) + } + else -> deserializePrimitiveMember(ctx, member, writer) + } + } + + // TODO - this could probably be moved to SerdeExt and commonized + + private fun Shape.shapeDeserializerDefinitionFile( + ctx: ProtocolGenerator.GenerationContext, + ): String { + val target = targetOrSelf(ctx.model) + val shapeName = StringUtils.capitalize(target.id.getName(ctx.service)) + return "${shapeName}ShapeDeserializer.kt" + } + private fun Shape.shapeDeserializer( + ctx: ProtocolGenerator.GenerationContext, + block: (fnName: String, writer: KotlinWriter) -> Unit, + ): Symbol { + val target = targetOrSelf(ctx.model) + val shapeName = StringUtils.capitalize(target.id.getName(ctx.service)) + val symbol = ctx.symbolProvider.toSymbol(this) + + val fnName = "deserialize${shapeName}Shape" + return buildSymbol { + name = fnName + namespace = ctx.settings.pkg.serde + definitionFile = shapeDeserializerDefinitionFile(ctx) + reference(symbol, SymbolReference.ContextOption.DECLARE) + renderBy = { + block(fnName, it) + } + } + } + + private fun deserializeShape( + ctx: ProtocolGenerator.GenerationContext, + shape: Shape, + block: KotlinWriter.() -> Unit, + ): Symbol { + val symbol = ctx.symbolProvider.toSymbol(shape) + val deserializeFn = shape.shapeDeserializer(ctx) { fnName, writer -> + writer.withBlock( + "internal fun #L(reader: #T): #T {", + "}", + fnName, + SerdeXml.TagReader, + symbol, + ) { + block(this) + } + } + return deserializeFn + } + + private fun deserializeList( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + val targetMember = target.member + val isSparse = target.hasTrait() + val deserializeFn = deserializeShape(ctx, target) { + write("val result = mutableListOf<#T#L>()", ctx.symbolProvider.toSymbol(targetMember), nullabilitySuffix(isSparse)) + deserializeLoop { + val memberName = targetMember.getTrait()?.value ?: targetMember.memberName + withBlock("#S -> {", "}", memberName) { + deserializeListInner(ctx, target, this) + write("result.add(el)") + } + } + write("return result") + } + writer.write("#T(curr)", deserializeFn) + } + + private fun deserializeFlatList( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + writer.withBlock("run {", "}") { + deserializeListInner(ctx, target, this) + write("#T(builder.#L, el)", RuntimeTypes.Core.Collections.createOrAppend, member.defaultName()) + } + } + + private fun deserializeListInner( + ctx: ProtocolGenerator.GenerationContext, + target: CollectionShape, + writer: KotlinWriter, + ) { + // <- sparse + // CDATA || TAG(s) <- not sparse + val isSparse = target.hasTrait() + with(writer) { + if (isSparse) { + openBlock("val el = if (curr.nextHasValue()) {") + .call { + deserializeMember(ctx, target.member, this) + } + .closeAndOpenBlock("} else {") + .write("null") + .closeBlock("}") + } else { + writeInline("val el = ") + deserializeMember(ctx, target.member, this) + } + } + } + + private fun deserializeMap( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + val keySymbol = ctx.symbolProvider.toSymbol(target.key) + val valueSymbol = ctx.symbolProvider.toSymbol(target.value) + val isSparse = target.hasTrait() + + val deserializeFn = deserializeShape(ctx, target) { + write("val result = mutableMapOf<#T, #T#L>()", keySymbol, valueSymbol, nullabilitySuffix(isSparse)) + deserializeLoop { + withBlock("#S -> {", "}", "entry") { + val deserializeEntryFn = deserializeMapEntry(ctx, target) + write("#T(result, curr)", deserializeEntryFn) + } + } + write("return result") + } + writer.write("#T(curr)", deserializeFn) + } + private fun deserializeFlatMap( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + val keySymbol = ctx.symbolProvider.toSymbol(target.key) + val valueSymbol = ctx.symbolProvider.toSymbol(target.value) + val isSparse = target.hasTrait() + writer.withBlock("run {", "}") { + write( + "val dest = builder.#L?.toMutableMap() ?: mutableMapOf<#T, #T#L>()", + member.defaultName(), + keySymbol, + valueSymbol, + nullabilitySuffix(isSparse), + ) + val deserializeEntryFn = deserializeMapEntry(ctx, target) + write("#T(dest, curr)", deserializeEntryFn) + write("dest") + } + } + + private fun deserializeMapEntry( + ctx: ProtocolGenerator.GenerationContext, + map: MapShape, + ): Symbol { + val shapeName = StringUtils.capitalize(map.id.getName(ctx.service)) + val keySymbol = ctx.symbolProvider.toSymbol(map.key) + val valueSymbol = ctx.symbolProvider.toSymbol(map.value) + val isSparse = map.hasTrait() + + return buildSymbol { + name = "deserialize${shapeName}Entry" + namespace = ctx.settings.pkg.serde + definitionFile = map.shapeDeserializerDefinitionFile(ctx) + renderBy = { writer -> + // NOTE: we make this internal rather than private because flat maps don't generate a + // dedicated map deserializer, they inline the entry deserialization since the map + // being built up is not processed all at once + writer.withBlock( + "internal fun $name(dest: MutableMap<#T, #T#L>, reader: #T) {", + "}", + keySymbol, + valueSymbol, + nullabilitySuffix(isSparse), + SerdeXml.TagReader, + ) { + write("var key: #T? = null", keySymbol) + write("var value: #T? = null", valueSymbol) + deserializeLoop { + val keyName = map.key.getTrait()?.value ?: map.key.memberName + writeInline("#S -> key = ", keyName) + deserializeMember(ctx, map.key, this) + + val valueName = map.value.getTrait()?.value ?: map.value.memberName + if (isSparse) { + openBlock("#S -> value = if (curr.nextHasValue()) {", valueName) + .call { + deserializeMember(ctx, map.value, this) + } + .closeAndOpenBlock("} else {") + .write("null") + .closeBlock("}") + } else { + writeInline("#S -> value = ", valueName) + deserializeMember(ctx, map.value, this) + } + } + write("if (key == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing key map entry") + write("if (value == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing value map entry") + write("dest[key] = value") + } + } + } + } + + private fun deserializePrimitiveMember( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + when (target.type) { + ShapeType.BLOB -> writer.write("curr.#T().#T()", SerdeXml.text, RuntimeTypes.Core.Text.Encoding.decodeBase64Bytes) + ShapeType.BOOLEAN -> writer.write("curr.#T()", SerdeXml.readBoolean) + ShapeType.STRING -> writer.write("curr.#T()", SerdeXml.text) + ShapeType.TIMESTAMP -> { + val trait = member.getTrait() ?: target.getTrait() + val tsFormat = trait?.format ?: defaultTimestampFormat + + // FIXME - reconcile with utility function that already exists + val fromFn = when (tsFormat) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "fromEpochSeconds" + TimestampFormatTrait.Format.DATE_TIME -> "fromIso8601" + TimestampFormatTrait.Format.HTTP_DATE -> "fromRfc5322" + else -> throw CodegenException("unknown timestamp format: $tsFormat") + } + writer.write("#T.#L(curr.#T())", RuntimeTypes.Core.Instant, fromFn, SerdeXml.text) + } + ShapeType.BYTE -> writer.write("curr.#T()", SerdeXml.readByte) + ShapeType.SHORT -> writer.write("curr.#T()", SerdeXml.readShort) + ShapeType.INTEGER -> writer.write("curr.#T()", SerdeXml.readInt) + ShapeType.LONG -> writer.write("curr.#T()", SerdeXml.readLong) + ShapeType.FLOAT -> writer.write("curr.#T()", SerdeXml.readFloat) + ShapeType.DOUBLE -> writer.write("curr.#T()", SerdeXml.readDouble) + ShapeType.BIG_DECIMAL -> writer.write("#T(curr.#T())", RuntimeTypes.Core.Content.BigDecimal, SerdeXml.text) + ShapeType.BIG_INTEGER -> writer.write("#T(curr.#T())", RuntimeTypes.Core.Content.BigInteger, SerdeXml.text) + ShapeType.ENUM -> writer.write("#T.fromValue(curr.#T())", ctx.symbolProvider.toSymbol(target), SerdeXml.text) + ShapeType.INT_ENUM -> writer.write("#T.fromValue(curr.#T())", ctx.symbolProvider.toSymbol(target), SerdeXml.readInt) + else -> error("unknown primitive member shape $member") + } + } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 87e9b9cc3..eea8dd4af 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,5 +1,5 @@ [versions] -kotlin-version = "1.9.21" +kotlin-version = "1.9.22" dokka-version = "1.9.10" aws-kotlin-repo-tools-version = "0.4.0" diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt index e57fefbee..c355bb720 100644 --- a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt @@ -8,10 +8,11 @@ package aws.smithy.kotlin.runtime.collections * Creates a new list or appends to an existing one if not null. * * If [dest] is null this function creates a new list with element [x] and returns it. - * Otherwise, it appends [x] to [dest] and returns the given [dest] list. + * Otherwise, it appends [x] to [dest] and returns the mutated list. */ -public fun createOrAppend(dest: MutableList?, x: T): MutableList { - if (dest == null) return mutableListOf(x) - dest.add(x) - return dest +public fun createOrAppend(dest: List?, x: T): List { + if (dest == null) return listOf(x) + val mut = dest.toMutableList() + mut.add(x) + return mut } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt index 60112b983..7ccebda99 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt @@ -33,6 +33,11 @@ public class TagReader( return reader.nextToken() } + public fun nextHasValue(): Boolean { + if (closed) return false + return reader.peek() !is XmlToken.EndElement + } + public fun skipNext() { if (closed) return reader.skipNext() @@ -89,15 +94,6 @@ public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): TagReader { return TagReader(this, reader) } -// @InternalApi -// public fun XmlToken.BeginElement.decode(reader: XmlStreamReader, block: TagReader.() -> T): T { -// val scoped = tagReader(reader) -// val result = block(scoped) -// // exhaust this reader -// scoped.drop() -// return result -// } - /** * Consume the next token and map the data value from it using [transform] * diff --git a/tests/codegen/serde-tests/build.gradle.kts b/tests/codegen/serde-tests/build.gradle.kts index 36e94a306..e963e664e 100644 --- a/tests/codegen/serde-tests/build.gradle.kts +++ b/tests/codegen/serde-tests/build.gradle.kts @@ -23,6 +23,10 @@ dependencies { tasks.generateSmithyProjections { smithyBuildConfigs.set(files("smithy-build.json")) inputs.dir(project.layout.projectDirectory.dir("model")) + listOf("xml", "json").forEach { projectionName -> + val fromDir = smithyBuild.smithyKotlinProjectionSrcDir(projectionName) + outputs.dir(fromDir) + } buildClasspath.set(codegen) } @@ -58,6 +62,9 @@ val stageGeneratedSources = tasks.register("stageGeneratedSources") { group = "codegen" dependsOn(tasks.generateSmithyProjections) outputs.dir(generatedSrcDir) + // FIXME - this task up-to-date checks are wrong, likely something is not setup right with inputs/outputs somewhere + // for now just always run it + outputs.upToDateWhen { false } doLast { listOf("xml", "json").forEach { projectionName -> val fromDir = smithyBuild.smithyKotlinProjectionSrcDir(projectionName) diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt index 2b8ed1792..89df7fee3 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt @@ -4,25 +4,42 @@ */ package aws.smithy.kotlin.tests.serde -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer -import aws.smithy.kotlin.runtime.serde.xml.XmlSerializer +import aws.smithy.kotlin.runtime.serde.xml.* import aws.smithy.kotlin.runtime.smithy.test.assertXmlStringsEqual import kotlin.test.assertEquals abstract class AbstractXmlTest { + // // FIXME - remove before merge - this test suite was put in place prior to changing the implementation to + // // verify everything works + // fun testRoundTrip( + // expected: T, + // payload: String, + // serializerFn: (XmlSerializer, T) -> Unit, + // deserializerFn: (XmlDeserializer) -> T, + // ) { + // val serializer = XmlSerializer() + // serializerFn(serializer, expected) + // val actualPayload = serializer.toByteArray().decodeToString() + // assertXmlStringsEqual(payload, actualPayload) + // + // val deserializer = XmlDeserializer(payload.encodeToByteArray()) + // val actualDeserialized = deserializerFn(deserializer) + // assertEquals(expected, actualDeserialized) + // } + fun testRoundTrip( expected: T, payload: String, serializerFn: (XmlSerializer, T) -> Unit, - deserializerFn: (XmlDeserializer) -> T, + deserializerFn: (TagReader) -> T, ) { val serializer = XmlSerializer() serializerFn(serializer, expected) val actualPayload = serializer.toByteArray().decodeToString() assertXmlStringsEqual(payload, actualPayload) - val deserializer = XmlDeserializer(payload.encodeToByteArray()) - val actualDeserialized = deserializerFn(deserializer) + val reader = xmlStreamReader(payload.encodeToByteArray()).root() + val actualDeserialized = deserializerFn(reader) assertEquals(expected, actualDeserialized) } } From 130d2158892dcb4b6dd94eaa8d7cee8cbc5c88eb Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Thu, 22 Feb 2024 21:56:26 -0500 Subject: [PATCH 05/25] refactor to use result --- .../kotlin/codegen/core/RuntimeTypes.kt | 23 ++-- .../protocol/HttpBindingProtocolGenerator.kt | 19 ++- .../codegen/rendering/serde/SerdeExt.kt | 33 +++-- .../serde/SerializeStructGenerator.kt | 7 -- .../rendering/serde/XmlParserGenerator.kt | 115 +++++++++++++----- .../smithy/kotlin/runtime/util/ResultExt.kt | 15 +++ .../smithy/kotlin/runtime/serde/Exceptions.kt | 11 ++ .../smithy/kotlin/runtime/serde/Parsers.kt | 86 +++++++++++++ .../kotlin/runtime/serde/xml/TagReader.kt | 50 +++++--- .../kotlin/runtime/serde/xml/XmlSerializer.kt | 2 +- .../runtime/serde/xml/XmlStreamWriter.kt | 2 +- .../kotlin/runtime/serde/xml/XmlToken.kt | 3 + .../kotlin/runtime/serde/xml/TagReaderTest.kt | 4 +- .../runtime/serde/xml/XmlStreamWriterTest.kt | 6 +- 14 files changed, 291 insertions(+), 85 deletions(-) create mode 100644 runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt create mode 100644 runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index 9267e3351..bc2026eda 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -231,6 +231,7 @@ object RuntimeTypes { val SerialKind = symbol("SerialKind") val SerializationException = symbol("SerializationException") val DeserializationException = symbol("DeserializationException") + val getOrDeserializeErr = symbol("getOrDeserializeErr") val serializeStruct = symbol("serializeStruct") val serializeList = symbol("serializeList") @@ -242,6 +243,18 @@ object RuntimeTypes { val asSdkSerializable = symbol("asSdkSerializable") val field = symbol("field") + val parse = symbol("parse") + val parseInt = symbol("parseInt") + val parseShort = symbol("parseShort") + val parseLong = symbol("parseLong") + val parseFloat = symbol("parseFloat") + val parseDouble = symbol("parseDouble") + val parseByte = symbol("parseByte") + val parseBoolean = symbol("parseBoolean") + val parseTimestamp = symbol("parseTimestamp") + val parseBigInteger = symbol("parseBigInteger") + val parseBigDecimal = symbol("parseBigDecimal") + object SerdeJson : RuntimeTypePackage(KotlinDependency.SERDE_JSON) { val JsonSerialName = symbol("JsonSerialName") val JsonSerializer = symbol("JsonSerializer") @@ -267,14 +280,8 @@ object RuntimeTypes { val TagReader = symbol("TagReader") val xmlStreamReader = symbol("xmlStreamReader") val root = symbol("root") - val text = symbol("text") - val readInt = symbol("readInt") - val readShort = symbol("readShort") - val readLong = symbol("readLong") - val readFloat = symbol("readFloat") - val readDouble = symbol("readDouble") - val readByte = symbol("readByte") - val readBoolean = symbol("readBoolean") + val data = symbol("data") + val tryData = symbol("tryData") } object SerdeFormUrl : RuntimeTypePackage(KotlinDependency.SERDE_FORM_URL) { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt index 9eed98597..e74474ba3 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt @@ -14,7 +14,7 @@ import software.amazon.smithy.kotlin.codegen.lang.toEscapedLiteral import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.rendering.serde.deserializerName import software.amazon.smithy.kotlin.codegen.rendering.serde.formatInstant -import software.amazon.smithy.kotlin.codegen.rendering.serde.parseInstant +import software.amazon.smithy.kotlin.codegen.rendering.serde.parseInstantExpr import software.amazon.smithy.kotlin.codegen.rendering.serde.serializerName import software.amazon.smithy.kotlin.codegen.utils.getOrNull import software.amazon.smithy.model.Model @@ -813,14 +813,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { HttpBinding.Location.HEADER, defaultTimestampFormat, ) - writer - .addImport(RuntimeTypes.Core.Instant) - .write( - "builder.#L = response.headers[#S]?.let { #L }", - memberName, - headerName, - parseInstant("it", tsFormat), - ) + writer.write( + "builder.#L = response.headers[#S]?.let { #L }", + memberName, + headerName, + writer.parseInstantExpr("it", tsFormat), + ) } is ListShape -> { // member > boolean, number, string, or timestamp @@ -849,8 +847,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { if (tsFormat == TimestampFormatTrait.Format.HTTP_DATE) { splitFn = "splitHttpDateHeaderListValues" } - writer.addImport(RuntimeTypes.Core.Instant) - parseInstant("it", tsFormat) + writer.parseInstantExpr("it", tsFormat) } is StringShape -> { when { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt index 7f53e0755..036cb3d79 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt @@ -9,8 +9,7 @@ import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.kotlin.codegen.KotlinSettings -import software.amazon.smithy.kotlin.codegen.core.SymbolRenderer -import software.amazon.smithy.kotlin.codegen.core.defaultName +import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.core.mangledSuffix import software.amazon.smithy.kotlin.codegen.model.buildSymbol import software.amazon.smithy.model.Model @@ -216,11 +215,31 @@ fun formatInstant(paramName: String, tsFmt: TimestampFormatTrait.Format, forceSt * @param paramName The name of the local identifier to convert to an `Instant` * @param tsFmt The timestamp format [paramName] is expected to be converted from */ -fun parseInstant(paramName: String, tsFmt: TimestampFormatTrait.Format): String = when (tsFmt) { - TimestampFormatTrait.Format.EPOCH_SECONDS -> "Instant.fromEpochSeconds($paramName)" - TimestampFormatTrait.Format.DATE_TIME -> "Instant.fromIso8601($paramName)" - TimestampFormatTrait.Format.HTTP_DATE -> "Instant.fromRfc5322($paramName)" - else -> throw CodegenException("unknown timestamp format: $tsFmt") +fun KotlinWriter.parseInstantExpr(paramName: String, tsFmt: TimestampFormatTrait.Format): String { + val fn = when (tsFmt) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "fromEpochSeconds" + TimestampFormatTrait.Format.DATE_TIME -> "fromIso8601" + TimestampFormatTrait.Format.HTTP_DATE -> "fromRfc5322" + else -> throw CodegenException("unknown timestamp format: $tsFmt") + } + return format("#T.#L(#L)", RuntimeTypes.Core.Instant, fn, paramName) +} + +fun TimestampFormatTrait.Format.toRuntimeEnum(): String = when (this) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "TimestampFormat.EPOCH_SECONDS" + TimestampFormatTrait.Format.DATE_TIME -> "TimestampFormat.ISO_8601" + TimestampFormatTrait.Format.HTTP_DATE -> "TimestampFormat.RFC_5322" + else -> throw CodegenException("unknown timestamp format: $this") +} + +fun TimestampFormatTrait.Format.toRuntimeEnum(writer: KotlinWriter): String { + val enum = when (this) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "EPOCH_SECONDS" + TimestampFormatTrait.Format.DATE_TIME -> "ISO_8601" + TimestampFormatTrait.Format.HTTP_DATE -> "RFC_5322" + TimestampFormatTrait.Format.UNKNOWN -> error("unknown timestamp format trait") + } + return writer.format("#T.#L", RuntimeTypes.Core.TimestampFormat, enum) } /** diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt index 017f612bd..ff099b0c9 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt @@ -680,10 +680,3 @@ open class SerializeStructGenerator( return "serialize$suffix" } } - -fun TimestampFormatTrait.Format.toRuntimeEnum(): String = when (this) { - TimestampFormatTrait.Format.EPOCH_SECONDS -> "TimestampFormat.EPOCH_SECONDS" - TimestampFormatTrait.Format.DATE_TIME -> "TimestampFormat.ISO_8601" - TimestampFormatTrait.Format.HTTP_DATE -> "TimestampFormat.RFC_5322" - else -> throw CodegenException("unknown timestamp format: $this") -} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index cfddf16a5..8ac13c0e8 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -5,10 +5,10 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde -import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes.Serde import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes.Serde.SerdeXml import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex @@ -17,6 +17,7 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingConte import software.amazon.smithy.model.shapes.* import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.model.traits.XmlAttributeTrait import software.amazon.smithy.model.traits.XmlFlattenedTrait import software.amazon.smithy.model.traits.XmlNameTrait import software.amazon.smithy.utils.StringUtils @@ -92,7 +93,7 @@ open class XmlParserGenerator( // val name = ctx.symbolProvider.toSymbol(shape).name // DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render() } else { - deserializeStruct(ctx, shape, members, writer) + deserializeStruct(ctx, members, writer) } } @@ -179,14 +180,21 @@ open class XmlParserGenerator( } private fun deserializeStruct( ctx: ProtocolGenerator.GenerationContext, - shape: Shape, members: List, writer: KotlinWriter, ) { - // TODO - split attribute members and non attribute members - // TODO - don't generate a parse loop if no attribute members + // split attribute members and non attribute members + val attributeMembers = members.filter { it.hasTrait() } + attributeMembers.forEach { member -> + deserializeAttributeMember(ctx, member, writer) + } + + val payloadMembers = members.filterNot { it.hasTrait() } + // don't generate a parse loop if no attribute members + if (payloadMembers.isEmpty()) return + writer.write("") writer.deserializeLoop { - members.forEach { member -> + payloadMembers.forEach { member -> val name = member.getTrait()?.value ?: member.memberName write("// ${member.memberName} ${escape(member.id.toString())}") writeInline("#S -> builder.#L = ", name, member.defaultName()) @@ -195,6 +203,22 @@ open class XmlParserGenerator( } } + private fun deserializeAttributeMember( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter, + ) { + val memberName = member.getTrait()?.value ?: member.memberName + writer.withBlock( + "reader.startTag.getAttr(#S)?.let {", + "}", + memberName, + ) { + writeInline("builder.#L = ", member.defaultName()) + deserializePrimitiveMember(ctx, member, "it", textExprIsResult = false, this) + } + } + private fun deserializeMember( ctx: ProtocolGenerator.GenerationContext, member: MemberShape, @@ -220,7 +244,13 @@ open class XmlParserGenerator( val deserializeFn = documentDeserializer(ctx, target) writer.write("#T(curr)", deserializeFn) } - else -> deserializePrimitiveMember(ctx, member, writer) + else -> deserializePrimitiveMember( + ctx, + member, + writer.format("curr.#T()", SerdeXml.tryData), + textExprIsResult = true, + writer, + ) } } @@ -423,7 +453,9 @@ open class XmlParserGenerator( } } write("if (key == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing key map entry") - write("if (value == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing value map entry") + if (!isSparse) { + write("if (value == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing value map entry") + } write("dest[key] = value") } } @@ -433,37 +465,60 @@ open class XmlParserGenerator( private fun deserializePrimitiveMember( ctx: ProtocolGenerator.GenerationContext, member: MemberShape, + textExpr: String, + textExprIsResult: Boolean, writer: KotlinWriter, ) { val target = ctx.model.expectShape(member.target) - when (target.type) { - ShapeType.BLOB -> writer.write("curr.#T().#T()", SerdeXml.text, RuntimeTypes.Core.Text.Encoding.decodeBase64Bytes) - ShapeType.BOOLEAN -> writer.write("curr.#T()", SerdeXml.readBoolean) - ShapeType.STRING -> writer.write("curr.#T()", SerdeXml.text) + + val parseFn = when (target.type) { + ShapeType.BLOB -> writer.format("#T { it.#T() } ", Serde.parse, RuntimeTypes.Core.Text.Encoding.decodeBase64Bytes) + ShapeType.BOOLEAN -> writer.format("#T()", Serde.parseBoolean) + ShapeType.STRING -> { + if (!textExprIsResult) { + writer.write(textExpr) + return + } else { + null + } + } ShapeType.TIMESTAMP -> { val trait = member.getTrait() ?: target.getTrait() val tsFormat = trait?.format ?: defaultTimestampFormat - - // FIXME - reconcile with utility function that already exists - val fromFn = when (tsFormat) { - TimestampFormatTrait.Format.EPOCH_SECONDS -> "fromEpochSeconds" - TimestampFormatTrait.Format.DATE_TIME -> "fromIso8601" - TimestampFormatTrait.Format.HTTP_DATE -> "fromRfc5322" - else -> throw CodegenException("unknown timestamp format: $tsFormat") + // val fromArg = writer.format("curr.#T()") + // val fmtExpr = writer.parseInstantExpr(fromArg, tsFormat) + // writer.write(fmtExpr) + val runtimeEnum = tsFormat.toRuntimeEnum(writer) + writer.format("#T(#L)", Serde.parseTimestamp, runtimeEnum) + } + ShapeType.BYTE -> writer.format("#T()", Serde.parseByte) + ShapeType.SHORT -> writer.format("#T()", Serde.parseShort) + ShapeType.INTEGER -> writer.format("#T()", Serde.parseInt) + ShapeType.LONG -> writer.format("#T()", Serde.parseLong) + ShapeType.FLOAT -> writer.format("#T()", Serde.parseFloat) + ShapeType.DOUBLE -> writer.format("#T()", Serde.parseDouble) + ShapeType.BIG_DECIMAL -> writer.format("#T()", Serde.parseBigDecimal) + ShapeType.BIG_INTEGER -> writer.format("#T()", Serde.parseBigInteger) + ShapeType.ENUM -> { + if (!textExprIsResult) { + writer.write("#T.fromValue(#L)", ctx.symbolProvider.toSymbol(target), textExpr) + return } - writer.write("#T.#L(curr.#T())", RuntimeTypes.Core.Instant, fromFn, SerdeXml.text) + writer.format("#T { #T.fromValue(it) } ", Serde.parse, ctx.symbolProvider.toSymbol(target)) + } + ShapeType.INT_ENUM -> { + writer.format("#T { #T.fromValue(it.toInt()) } ", Serde.parse, ctx.symbolProvider.toSymbol(target)) } - ShapeType.BYTE -> writer.write("curr.#T()", SerdeXml.readByte) - ShapeType.SHORT -> writer.write("curr.#T()", SerdeXml.readShort) - ShapeType.INTEGER -> writer.write("curr.#T()", SerdeXml.readInt) - ShapeType.LONG -> writer.write("curr.#T()", SerdeXml.readLong) - ShapeType.FLOAT -> writer.write("curr.#T()", SerdeXml.readFloat) - ShapeType.DOUBLE -> writer.write("curr.#T()", SerdeXml.readDouble) - ShapeType.BIG_DECIMAL -> writer.write("#T(curr.#T())", RuntimeTypes.Core.Content.BigDecimal, SerdeXml.text) - ShapeType.BIG_INTEGER -> writer.write("#T(curr.#T())", RuntimeTypes.Core.Content.BigInteger, SerdeXml.text) - ShapeType.ENUM -> writer.write("#T.fromValue(curr.#T())", ctx.symbolProvider.toSymbol(target), SerdeXml.text) - ShapeType.INT_ENUM -> writer.write("#T.fromValue(curr.#T())", ctx.symbolProvider.toSymbol(target), SerdeXml.readInt) else -> error("unknown primitive member shape $member") } + + val escapedErrMessage = "expected $target".replace("$", "$$") + writer.write(textExpr) + .indent() + .callIf(parseFn != null) { + writer.write(".#L", parseFn) + } + .write(".#T { #S }", Serde.getOrDeserializeErr, escapedErrMessage) + .dedent() } } diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt new file mode 100644 index 000000000..6ba8947d3 --- /dev/null +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt @@ -0,0 +1,15 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.util + +/** + * Maps the exception to a new error if this instance represents [failure][Result.isFailure], leaving + * a [success][Result.isSuccess] value untouched. + */ +public inline fun Result.mapErr(onFailure: (Throwable) -> Throwable): Result = + when (val ex = exceptionOrNull()) { + null -> this + else -> Result.failure(onFailure(ex)) + } diff --git a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt index e774a49e6..18f908727 100644 --- a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt +++ b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt @@ -5,6 +5,8 @@ package aws.smithy.kotlin.runtime.serde import aws.smithy.kotlin.runtime.ClientException +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.util.mapErr /** * Exception class for all serialization errors @@ -33,3 +35,12 @@ public class DeserializationException : ClientException { public constructor(cause: Throwable?) : super(cause) } + +/** + * Get the underlying [success][Result.isSuccess] value or wrap the failure in a [DeserializationException] + * and throw it. + */ +@InternalApi +public inline fun Result.getOrDeserializeErr(errorMessage: () -> String): T = + mapErr { DeserializationException(errorMessage(), it) } + .getOrThrow() diff --git a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt new file mode 100644 index 000000000..7c86d51ea --- /dev/null +++ b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt @@ -0,0 +1,86 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.serde + +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.content.BigDecimal +import aws.smithy.kotlin.runtime.content.BigInteger +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.time.TimestampFormat + +@InternalApi +public inline fun String.parse(transform: (String) -> T): Result = runCatching { transform(this) } + +@InternalApi +public fun String.parseBoolean(): Result = parse(String::toBoolean) + +@InternalApi +public fun String.parseInt(): Result = parse(String::toInt) + +@InternalApi +public fun String.parseShort(): Result = parse(String::toShort) + +@InternalApi +public fun String.parseLong(): Result = parse(String::toLong) + +@InternalApi +public fun String.parseFloat(): Result = parse(String::toFloat) + +@InternalApi +public fun String.parseDouble(): Result = parse(String::toDouble) + +@InternalApi +public fun String.parseByte(): Result = parse { it.toInt().toByte() } + +public fun String.parseBigInteger(): Result = parse(::BigInteger) + +@InternalApi +public fun String.parseBigDecimal(): Result = parse(::BigDecimal) + +private fun String.toTimestamp(fmt: TimestampFormat): Instant = when (fmt) { + TimestampFormat.ISO_8601_CONDENSED, + TimestampFormat.ISO_8601_CONDENSED_DATE, + TimestampFormat.ISO_8601, + -> Instant.fromIso8601(this) + + TimestampFormat.RFC_5322 -> Instant.fromRfc5322(this) + TimestampFormat.EPOCH_SECONDS -> Instant.fromEpochSeconds(this) +} + +@InternalApi +public fun String.parseTimestamp(fmt: TimestampFormat): Result = parse { it.toTimestamp(fmt) } + +@InternalApi +public inline fun Result.parse(transform: (String) -> T): Result = mapCatching(transform) + +@InternalApi +public fun Result.parseBoolean(): Result = parse(String::toBoolean) + +@InternalApi +public fun Result.parseInt(): Result = parse(String::toInt) + +@InternalApi +public fun Result.parseShort(): Result = parse(String::toShort) + +@InternalApi +public fun Result.parseLong(): Result = parse(String::toLong) + +@InternalApi +public fun Result.parseFloat(): Result = parse(String::toFloat) + +@InternalApi +public fun Result.parseDouble(): Result = parse(String::toDouble) + +@InternalApi +public fun Result.parseByte(): Result = parse { it.toInt().toByte() } + +@InternalApi +public fun Result.parseBigInteger(): Result = parse(::BigInteger) + +@InternalApi +public fun Result.parseBigDecimal(): Result = parse(::BigDecimal) + +@InternalApi +public fun Result.parseTimestamp(fmt: TimestampFormat): Result = parse { it.toTimestamp(fmt) } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt index 7ccebda99..ad3a30cb0 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt @@ -79,6 +79,7 @@ public class TagReader( } } +@InternalApi public fun XmlStreamReader.root(): TagReader { val start = seek() ?: error("expected start tag: last = $lastToken") return start.tagReader(this) @@ -99,25 +100,44 @@ public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): TagReader { * * If the next token is not [XmlToken.Text] an exception will be thrown */ -public fun TagReader.map(transform: (String) -> T): T = - transform(text()) +@InternalApi +public inline fun TagReader.mapData(transform: (String) -> T): T = + transform(data()) -public fun TagReader.text(): String = +@InternalApi +public fun TagReader.data(): String = when (val next = nextToken()) { is XmlToken.Text -> next.value ?: "" null, is XmlToken.EndElement -> "" else -> throw DeserializationException("expected XmlToken.Text element, found $next") } -private fun TagReader.mapOrThrow(expected: String, mapper: (String) -> T?): T = - map { raw -> - mapper(raw) ?: throw DeserializationException("could not deserialize $raw as $expected for tag ${this.startTag}") - } - -public fun TagReader.readInt(): Int = mapOrThrow("Int", String::toIntOrNull) -public fun TagReader.readShort(): Short = mapOrThrow("Short", String::toShortOrNull) -public fun TagReader.readLong(): Long = mapOrThrow("Long", String::toLongOrNull) -public fun TagReader.readFloat(): Float = mapOrThrow("Float", String::toFloatOrNull) -public fun TagReader.readDouble(): Double = mapOrThrow("Double", String::toDoubleOrNull) -public fun TagReader.readByte(): Byte = mapOrThrow("Byte") { it.toIntOrNull()?.toByte() } -public fun TagReader.readBoolean(): Boolean = mapOrThrow("Boolean", String::toBoolean) +@InternalApi +public fun TagReader.tryData(): Result = runCatching { data() } + +// +// private fun TagReader.mapOrThrow(expected: String, mapper: (String) -> T?): T = +// map { raw -> +// mapper(raw) ?: throw DeserializationException("could not deserialize $raw as $expected for tag ${this.startTag}") +// } +// +// @InternalApi +// public fun TagReader.readInt(): Int = mapOrThrow("Int", String::toIntOrNull) +// +// @InternalApi +// public fun TagReader.readShort(): Short = mapOrThrow("Short", String::toShortOrNull) +// +// @InternalApi +// public fun TagReader.readLong(): Long = mapOrThrow("Long", String::toLongOrNull) +// +// @InternalApi +// public fun TagReader.readFloat(): Float = mapOrThrow("Float", String::toFloatOrNull) +// +// @InternalApi +// public fun TagReader.readDouble(): Double = mapOrThrow("Double", String::toDoubleOrNull) +// +// @InternalApi +// public fun TagReader.readByte(): Byte = mapOrThrow("Byte") { it.toIntOrNull()?.toByte() } +// +// @InternalApi +// public fun TagReader.readBoolean(): Boolean = mapOrThrow("Boolean", String::toBoolean) diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt index 825993143..ff50a3d3c 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt @@ -184,7 +184,7 @@ public class XmlSerializer(private val xmlWriter: XmlStreamWriter = xmlStreamWri xmlWriter.text(value.toPlainString()) } - private fun serializeNumber(value: Number): Unit = xmlWriter.text(value) + private fun serializeNumber(value: Number): Unit = xmlWriter.data(value) override fun serializeString(value: String) { xmlWriter.text(value) diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt index 1c529d912..a4f7b6176 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt @@ -78,7 +78,7 @@ public interface XmlStreamWriter { } @InternalApi -public fun XmlStreamWriter.text(text: Number) { +public fun XmlStreamWriter.data(text: Number) { this.text(text.toString()) } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt index ca1bfd8c8..4b1981a68 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt @@ -53,6 +53,9 @@ public sealed class XmlToken { public constructor(depth: Int, name: String, attributes: Map) : this(depth, QualifiedName(name), attributes) override fun toString(): String = "<${this.name} (${this.depth})>" + + // convenience function for codegen + public fun getAttr(qualified: String): String? = attributes[QualifiedName(qualified)] } /** diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt index 0ae02ea6d..411ec6a66 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt @@ -123,10 +123,10 @@ class TagReaderTest { assertEquals(2, curr.nextTag()?.readInt()) } "Child2" -> { - assertEquals("this is an a", curr.nextTag()?.text()) + assertEquals("this is an a", curr.nextTag()?.data()) // intentionally ignore the next tag and don't consume the entire child subtree } - "Child4" -> assertEquals(" ", curr.nextTag()?.text()) + "Child4" -> assertEquals(" ", curr.nextTag()?.data()) else -> {} } // consume the current tag entirely before trying to process the next diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt index 316f7034e..38ea083e0 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt @@ -167,7 +167,7 @@ fun writeMessage(writer: XmlStreamWriter, message: Message) { writer.apply { startTag("message") startTag("id") - text(message.id) + data(message.id) endTag("id") startTag("text") text(message.text) @@ -190,7 +190,7 @@ fun writeUser(writer: XmlStreamWriter, user: User) { writer.text(user.name) writer.endTag("name") writer.startTag("followers_count") - writer.text(user.followersCount) + writer.data(user.followersCount) writer.endTag("followers_count") writer.endTag("user") } @@ -200,7 +200,7 @@ fun writeDoublesArray(writer: XmlStreamWriter, doubles: Array?) { if (doubles != null) { for (value in doubles) { writer.startTag("position") - writer.text(value) + writer.data(value) writer.endTag("position") } } From ec5736eec08a3f9732b36184e60939083f5c188c Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Thu, 22 Feb 2024 22:19:25 -0500 Subject: [PATCH 06/25] almost there --- .../rendering/serde/XmlParserGenerator.kt | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index 8ac13c0e8..e194f417c 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -89,9 +89,7 @@ open class XmlParserGenerator( writer: KotlinWriter, ) { if (shape.isUnionShape) { - // TODO - parse unions - // val name = ctx.symbolProvider.toSymbol(shape).name - // DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render() + deserializeUnion(ctx, members, writer) } else { deserializeStruct(ctx, members, writer) } @@ -178,6 +176,28 @@ open class XmlParserGenerator( write("curr.drop()") } } + + private fun deserializeUnion( + ctx: ProtocolGenerator.GenerationContext, + members: List, + writer: KotlinWriter, + ) { + writer.deserializeLoop { + members.forEach { member -> + val name = member.getTrait()?.value ?: member.memberName + write("// ${member.memberName} ${escape(member.id.toString())}") + val unionTypeName = member.unionTypeName(ctx) + val unionVariantName = member.unionVariantName() + withBlock("#S -> value = #L(", ")", name, unionTypeName) { + // FIXME - need to propagate accumulator + // should be value?.as#LOrNull() + val accumFn = "value?.as${unionVariantName}OrNull()" + deserializeMember(ctx, member, writer) + } + } + } + } + private fun deserializeStruct( ctx: ProtocolGenerator.GenerationContext, members: List, From 1ce6139b4cc029aaff022727d10d5a1c8a7832f6 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Fri, 23 Feb 2024 08:14:55 -0500 Subject: [PATCH 07/25] fix union deserialization of flat collections --- .../rendering/serde/XmlParserGenerator.kt | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index e194f417c..edd49bd58 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -173,6 +173,9 @@ open class XmlParserGenerator( write("else -> {}") } } + // maintain stream reader state by dropping the current element and all it's children + // this ensures nested elements with potentially the same name as a higher level element + // are not erroneously returned and matched by `nextTag()` write("curr.drop()") } } @@ -187,11 +190,7 @@ open class XmlParserGenerator( val name = member.getTrait()?.value ?: member.memberName write("// ${member.memberName} ${escape(member.id.toString())}") val unionTypeName = member.unionTypeName(ctx) - val unionVariantName = member.unionVariantName() withBlock("#S -> value = #L(", ")", name, unionTypeName) { - // FIXME - need to propagate accumulator - // should be value?.as#LOrNull() - val accumFn = "value?.as${unionVariantName}OrNull()" deserializeMember(ctx, member, writer) } } @@ -345,6 +344,19 @@ open class XmlParserGenerator( writer.write("#T(curr)", deserializeFn) } + private fun flatCollectionAccumulatorExpr( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + ): String = + when (val container = ctx.model.expectShape(member.container)) { + is StructureShape -> "builder.${member.defaultName()}" + is UnionShape -> { + val unionVariantName = member.unionVariantName() + "value?.as${unionVariantName}OrNull()" + } + else -> error("unexpected container shape $container for member $member") + } + private fun deserializeFlatList( ctx: ProtocolGenerator.GenerationContext, member: MemberShape, @@ -353,7 +365,8 @@ open class XmlParserGenerator( val target = ctx.model.expectShape(member.target) writer.withBlock("run {", "}") { deserializeListInner(ctx, target, this) - write("#T(builder.#L, el)", RuntimeTypes.Core.Collections.createOrAppend, member.defaultName()) + val accum = flatCollectionAccumulatorExpr(ctx, member) + write("#T(#L, el)", RuntimeTypes.Core.Collections.createOrAppend, accum) } } @@ -413,9 +426,10 @@ open class XmlParserGenerator( val valueSymbol = ctx.symbolProvider.toSymbol(target.value) val isSparse = target.hasTrait() writer.withBlock("run {", "}") { + val accum = flatCollectionAccumulatorExpr(ctx, member) write( - "val dest = builder.#L?.toMutableMap() ?: mutableMapOf<#T, #T#L>()", - member.defaultName(), + "val dest = #L?.toMutableMap() ?: mutableMapOf<#T, #T#L>()", + accum, keySymbol, valueSymbol, nullabilitySuffix(isSparse), @@ -472,9 +486,9 @@ open class XmlParserGenerator( deserializeMember(ctx, map.value, this) } } - write("if (key == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing key map entry") + write("if (key == null) throw #T(#S)", Serde.DeserializationException, "missing key map entry") if (!isSparse) { - write("if (value == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing value map entry") + write("if (value == null) throw #T(#S)", Serde.DeserializationException, "missing value map entry") } write("dest[key] = value") } From d9628fad82cf6efe4f9ae3fff5008425e88e20db Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Fri, 23 Feb 2024 08:21:02 -0500 Subject: [PATCH 08/25] enable interspersed flat tests --- .../smithy/kotlin/tests/serde/XmlListTest.kt | 44 +++++++------- .../smithy/kotlin/tests/serde/XmlMapTest.kt | 57 +++++++++++++++++++ 2 files changed, 81 insertions(+), 20 deletions(-) diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt index 8f2fb8573..8a9f489cd 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt @@ -4,10 +4,13 @@ */ package aws.smithy.kotlin.tests.serde +import aws.smithy.kotlin.runtime.serde.xml.root +import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader import aws.smithy.kotlin.tests.serde.xml.model.StructType import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument import kotlin.test.Test +import kotlin.test.assertEquals class XmlListTest : AbstractXmlTest() { @Test @@ -100,24 +103,25 @@ class XmlListTest : AbstractXmlTest() { testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) } - // FIXME - re-enable after we implement fix - // @Test - // fun testDeserializeInterspersedSparseLists() { - // // see https://github.com/awslabs/aws-sdk-kotlin/issues/1220 - // val expected = StructType { - // flatList = listOf("foo", "bar") - // secondFlatList = listOf(1, 2) - // } - // val payload = """ - // - // foo - // 1 - // bar - // 2 - // - // """.trimIndent() - // val deserializer = XmlDeserializer(payload.encodeToByteArray()) - // val actualDeserialized = deserializeStructTypeDocument(deserializer) - // assertEquals(expected, actualDeserialized) - // } + @Test + fun testInterspersedFlatLists() { + // see https://github.com/awslabs/aws-sdk-kotlin/issues/1220 + val expected = StructType { + flatList = listOf("foo", "bar") + secondFlatList = listOf(1, 2) + } + val payload = """ + + foo + 1 + bar + 2 + + """.trimIndent() + + // we don't round trip this because the format isn't going to match + val reader = xmlStreamReader(payload.encodeToByteArray()).root() + val actualDeserialized = deserializeStructTypeDocument(reader) + assertEquals(expected, actualDeserialized) + } } diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt index 7214b9db8..c5fcb5a39 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt @@ -4,11 +4,15 @@ */ package aws.smithy.kotlin.tests.serde +import aws.smithy.kotlin.runtime.serde.xml.root +import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader import aws.smithy.kotlin.tests.serde.xml.model.FooEnum import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.model.UnionType import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument import kotlin.test.Test +import kotlin.test.assertEquals class XmlMapTest : AbstractXmlTest() { @Test @@ -162,4 +166,57 @@ class XmlMapTest : AbstractXmlTest() { """.trimIndent() testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) } + + @Test + fun testInterspersedFlatMaps() { + // see https://github.com/awslabs/aws-sdk-kotlin/issues/1220 + val expected = StructType { + flatEnumMap = mapOf( + "foo" to FooEnum.Foo, + "bar" to FooEnum.Bar, + ) + unionField = UnionType.Struct( + StructType { + normalMap = mapOf("k1" to "v1", "k2" to "v2") + flatEnumMap = mapOf("inner" to FooEnum.Baz) + }, + ) + } + val payload = """ + + + foo + Foo + + + + + + + k1 + v1 + + + k2 + v2 + + + + inner + Baz + + + + + bar + Bar + + + """.trimIndent() + + // we don't round trip this because the format isn't going to match + val reader = xmlStreamReader(payload.encodeToByteArray()).root() + val actualDeserialized = deserializeStructTypeDocument(reader) + assertEquals(expected, actualDeserialized) + } } From 29e0823068f0362faeb72d39b5c284de676e415c Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Fri, 23 Feb 2024 08:30:52 -0500 Subject: [PATCH 09/25] enable more tests --- tests/codegen/serde-tests/model/shared.smithy | 6 +- tests/codegen/serde-tests/model/xml.smithy | 2 - .../smithy/kotlin/tests/serde/XmlUnionTest.kt | 129 ++++++++++++++++++ 3 files changed, 131 insertions(+), 6 deletions(-) diff --git a/tests/codegen/serde-tests/model/shared.smithy b/tests/codegen/serde-tests/model/shared.smithy index 87abe6612..5b89a7f2a 100644 --- a/tests/codegen/serde-tests/model/shared.smithy +++ b/tests/codegen/serde-tests/model/shared.smithy @@ -140,8 +140,7 @@ structure MapTypesMixin { union MapTypesUnionMixin { normalMap: StringMap, sparseMap: SparseStringMap, - // FIXME - doesn't work with current codegen for unions - // nestedMap: NestedStringMap, + nestedMap: NestedStringMap, } @mixin @@ -157,7 +156,6 @@ union ListTypesUnionMixin { sparseList: SparseStringList, - // FIXME - doesn't work with current codegen for unions - // nestedList: NestedStringList, + nestedList: NestedStringList, } diff --git a/tests/codegen/serde-tests/model/xml.smithy b/tests/codegen/serde-tests/model/xml.smithy index a360f6958..c06fd0ac2 100644 --- a/tests/codegen/serde-tests/model/xml.smithy +++ b/tests/codegen/serde-tests/model/xml.smithy @@ -83,6 +83,4 @@ union UnionType with [PrimitiveTypesUnionMixin, ListTypesUnionMixin, MapTypesUni fpDouble: Double, struct: StructType, - - // TODO - enum lists, timestamp lists, structure list, structure map, multiple flat lists interspersed (xml only) } diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt index 5e3d6502a..5ec3f5446 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt @@ -108,6 +108,55 @@ class XmlUnionTest : AbstractXmlTest() { testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) } + @Test + fun testFlatList() { + val expected = StructType { + unionField = UnionType.FlatList(listOf("foo", "bar")) + } + + val payload = """ + + + foo + bar + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNestedList() { + val expected = StructType { + unionField = UnionType.NestedList( + listOf( + listOf("a", "b", "c"), + listOf("x", "y", "z"), + ), + ) + } + + val payload = """ + + + + + a + b + c + + + x + y + z + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + @Test fun testNormalMap() { val expected = StructType { @@ -137,6 +186,86 @@ class XmlUnionTest : AbstractXmlTest() { testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) } + @Test + fun testNestedMap() { + val expected = StructType { + unionField = UnionType.NestedMap( + mapOf( + "foo" to mapOf( + "k1" to "v1", + "k2" to "v2", + ), + "bar" to mapOf( + "k3" to "v3", + "k4" to "v4", + ), + ), + ) + } + val payload = """ + + + + + foo + + + k1 + v1 + + + k2 + v2 + + + + + bar + + + k3 + v3 + + + k4 + v4 + + + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testFlatMap() { + val expected = StructType { + unionField = UnionType.FlatMap( + mapOf( + "foo" to "bar", + "bar" to "baz", + ), + ) + } + val payload = """ + + + + foo + bar + + + bar + baz + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + // FIXME - https://github.com/awslabs/smithy-kotlin/issues/1040 // @Test // fun testUnitField() { } From cdb47067482627256642c8e3d43145a609f01c73 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Fri, 23 Feb 2024 13:35:06 -0500 Subject: [PATCH 10/25] add hooks for unwrapping operation and error payloads and tracking correct tag to decode from --- .../smithy/kotlin/codegen/lang/KotlinTypes.kt | 1 + .../rendering/serde/XmlParserGenerator.kt | 158 +++++++++++++----- gradle.properties | 2 +- .../kotlin/tests/serde/AbstractXmlTest.kt | 18 -- 4 files changed, 119 insertions(+), 60 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt index 3f60e404c..fb967ccf7 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt @@ -41,6 +41,7 @@ object KotlinTypes { val List: Symbol = stdlibSymbol("List") val listOf: Symbol = stdlibSymbol("listOf") val MutableList: Symbol = stdlibSymbol("MutableList") + val MutableMap: Symbol = stdlibSymbol("MutableMap") val Map: Symbol = stdlibSymbol("Map") val mutableListOf: Symbol = stdlibSymbol("mutableListOf") val mutableMapOf: Symbol = stdlibSymbol("mutableMapOf") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index edd49bd58..bb5b12791 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -10,8 +10,10 @@ import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes.Serde import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes.Serde.SerdeXml +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex +import software.amazon.smithy.kotlin.codegen.model.traits.UnwrappedXmlOutput import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext import software.amazon.smithy.model.shapes.* @@ -31,6 +33,14 @@ open class XmlParserGenerator( private val defaultTimestampFormat: TimestampFormatTrait.Format, ) : StructuredDataParserGenerator { + /** + * Deserialization context that holds current state + * @param tagReader the name of the current tag reader to operate on + */ + data class SerdeCtx( + val tagReader: String, + ) + // FIXME - remove open fun descriptorGenerator( ctx: ProtocolGenerator.GenerationContext, @@ -77,21 +87,39 @@ open class XmlParserGenerator( documentMembers: List, writer: KotlinWriter, ) { - writer.write("val reader = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) + writer.write("val root = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) val shape = ctx.model.expectShape(op.output.get()) - renderDeserializerBody(ctx, shape, documentMembers, writer) + val serdeCtx = unwrapOperationBody(ctx, SerdeCtx("root"), op, writer) + + if (op.hasTrait()) { + renderDeserializerUnwrappedXmlBody(ctx, serdeCtx, shape, writer) + } else { + renderDeserializerBody(ctx, serdeCtx, shape, documentMembers, writer) + } } + /** + * Hook for protocols to perform logic prior to deserializing the operation output. + * Implementations must return the [SerdeCtx] to use for further deserialization. + */ + protected open fun unwrapOperationBody( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + op: OperationShape, + writer: KotlinWriter, + ): SerdeCtx = serdeCtx + protected fun renderDeserializerBody( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, shape: Shape, members: List, writer: KotlinWriter, ) { if (shape.isUnionShape) { - deserializeUnion(ctx, members, writer) + deserializeUnion(ctx, serdeCtx, members, writer) } else { - deserializeStruct(ctx, members, writer) + deserializeStruct(ctx, serdeCtx, members, writer) } } @@ -104,13 +132,14 @@ open class XmlParserGenerator( return shape.documentDeserializer(ctx.settings, symbol, members) { writer -> writer.openBlock("internal fun #identifier.name:L(reader: #T): #T {", SerdeXml.TagReader, symbol) .call { + val serdeCtx = SerdeCtx("reader") if (shape.isUnionShape) { writer.write("var value: #T? = null", symbol) - renderDeserializerBody(ctx, shape, members.toList(), writer) - writer.write("return value ?: throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "Deserialized union value unexpectedly null: ${symbol.name}") + renderDeserializerBody(ctx, serdeCtx, shape, members.toList(), writer) + writer.write("return value ?: throw #T(#S)", Serde.DeserializationException, "Deserialized union value unexpectedly null: ${symbol.name}") } else { writer.write("val builder = #T.Builder()", symbol) - renderDeserializerBody(ctx, shape, members.toList(), writer) + renderDeserializerBody(ctx, serdeCtx, shape, members.toList(), writer) writer.write("builder.correctErrors()") writer.write("return builder.build()") } @@ -130,13 +159,25 @@ open class XmlParserGenerator( val fnName = symbol.errorDeserializerName() writer.openBlock("internal fun #L(builder: #T.Builder, payload: ByteArray) {", fnName, symbol) .call { - writer.write("val reader = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) - renderDeserializerBody(ctx, errorShape, members, writer) + writer.write("val root = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) + val serdeCtx = unwrapOperationError(ctx, SerdeCtx("root"), errorShape, writer) + renderDeserializerBody(ctx, serdeCtx, errorShape, members, writer) } .closeBlock("}") } } + /** + * Hook for protocols to perform logic prior to deserializing an operation error. + * Implementations must return the [SerdeCtx] to use for further deserialization. + */ + protected open fun unwrapOperationError( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + errorShape: StructureShape, + writer: KotlinWriter, + ): SerdeCtx = serdeCtx + override fun payloadDeserializer( ctx: ProtocolGenerator.GenerationContext, shape: Shape, @@ -154,21 +195,22 @@ open class XmlParserGenerator( // short circuit when the shape has no modeled members to deserialize write("return #T.Builder().build()", symbol) } else { - write("val deserializer = #T(payload)", SerdeXml.XmlDeserializer) - write("return #T(deserializer)", deserializeFn) + writer.write("val root = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) + write("return #T(root)", deserializeFn) } } } } private fun KotlinWriter.deserializeLoop( + serdeCtx: SerdeCtx, ignoreUnexpected: Boolean = true, - block: KotlinWriter.() -> Unit, + block: KotlinWriter.(SerdeCtx) -> Unit, ) { withBlock("loop@while(true) {", "}") { - write("val curr = reader.nextTag() ?: break@loop") + write("val curr = ${serdeCtx.tagReader}.nextTag() ?: break@loop") withBlock("when(curr.startTag.name.tag) {", "}") { - block(this) + block(this, serdeCtx.copy(tagReader = "curr")) if (ignoreUnexpected) { write("else -> {}") } @@ -182,16 +224,17 @@ open class XmlParserGenerator( private fun deserializeUnion( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, members: List, writer: KotlinWriter, ) { - writer.deserializeLoop { + writer.deserializeLoop(serdeCtx) { innerCtx -> members.forEach { member -> val name = member.getTrait()?.value ?: member.memberName write("// ${member.memberName} ${escape(member.id.toString())}") val unionTypeName = member.unionTypeName(ctx) withBlock("#S -> value = #L(", ")", name, unionTypeName) { - deserializeMember(ctx, member, writer) + deserializeMember(ctx, innerCtx, member, writer) } } } @@ -199,37 +242,39 @@ open class XmlParserGenerator( private fun deserializeStruct( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, members: List, writer: KotlinWriter, ) { // split attribute members and non attribute members val attributeMembers = members.filter { it.hasTrait() } attributeMembers.forEach { member -> - deserializeAttributeMember(ctx, member, writer) + deserializeAttributeMember(ctx, serdeCtx, member, writer) } val payloadMembers = members.filterNot { it.hasTrait() } // don't generate a parse loop if no attribute members if (payloadMembers.isEmpty()) return writer.write("") - writer.deserializeLoop { + writer.deserializeLoop(serdeCtx) { innerCtx -> payloadMembers.forEach { member -> val name = member.getTrait()?.value ?: member.memberName write("// ${member.memberName} ${escape(member.id.toString())}") writeInline("#S -> builder.#L = ", name, member.defaultName()) - deserializeMember(ctx, member, writer) + deserializeMember(ctx, innerCtx, member, writer) } } } private fun deserializeAttributeMember( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, member: MemberShape, writer: KotlinWriter, ) { val memberName = member.getTrait()?.value ?: member.memberName writer.withBlock( - "reader.startTag.getAttr(#S)?.let {", + "${serdeCtx.tagReader}.startTag.getAttr(#S)?.let {", "}", memberName, ) { @@ -240,6 +285,7 @@ open class XmlParserGenerator( private fun deserializeMember( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, member: MemberShape, writer: KotlinWriter, ) { @@ -247,16 +293,16 @@ open class XmlParserGenerator( when (target.type) { ShapeType.LIST, ShapeType.SET -> { if (member.hasTrait()) { - deserializeFlatList(ctx, member, writer) + deserializeFlatList(ctx, serdeCtx, member, writer) } else { - deserializeList(ctx, member, writer) + deserializeList(ctx, serdeCtx, member, writer) } } ShapeType.MAP -> { if (member.hasTrait()) { - deserializeFlatMap(ctx, member, writer) + deserializeFlatMap(ctx, serdeCtx, member, writer) } else { - deserializeMap(ctx, member, writer) + deserializeMap(ctx, serdeCtx, member, writer) } } ShapeType.STRUCTURE, ShapeType.UNION -> { @@ -324,6 +370,7 @@ open class XmlParserGenerator( private fun deserializeList( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, member: MemberShape, writer: KotlinWriter, ) { @@ -332,16 +379,16 @@ open class XmlParserGenerator( val isSparse = target.hasTrait() val deserializeFn = deserializeShape(ctx, target) { write("val result = mutableListOf<#T#L>()", ctx.symbolProvider.toSymbol(targetMember), nullabilitySuffix(isSparse)) - deserializeLoop { + deserializeLoop(SerdeCtx(tagReader = "reader")) { innerCtx -> val memberName = targetMember.getTrait()?.value ?: targetMember.memberName withBlock("#S -> {", "}", memberName) { - deserializeListInner(ctx, target, this) + deserializeListInner(ctx, innerCtx, target, this) write("result.add(el)") } } write("return result") } - writer.write("#T(curr)", deserializeFn) + writer.write("#T(${serdeCtx.tagReader})", deserializeFn) } private fun flatCollectionAccumulatorExpr( @@ -359,12 +406,13 @@ open class XmlParserGenerator( private fun deserializeFlatList( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, member: MemberShape, writer: KotlinWriter, ) { val target = ctx.model.expectShape(member.target) writer.withBlock("run {", "}") { - deserializeListInner(ctx, target, this) + deserializeListInner(ctx, serdeCtx, target, this) val accum = flatCollectionAccumulatorExpr(ctx, member) write("#T(#L, el)", RuntimeTypes.Core.Collections.createOrAppend, accum) } @@ -372,6 +420,7 @@ open class XmlParserGenerator( private fun deserializeListInner( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, target: CollectionShape, writer: KotlinWriter, ) { @@ -380,44 +429,47 @@ open class XmlParserGenerator( val isSparse = target.hasTrait() with(writer) { if (isSparse) { - openBlock("val el = if (curr.nextHasValue()) {") + openBlock("val el = if (${serdeCtx.tagReader}.nextHasValue()) {") .call { - deserializeMember(ctx, target.member, this) + deserializeMember(ctx, serdeCtx, target.member, this) } .closeAndOpenBlock("} else {") .write("null") .closeBlock("}") } else { writeInline("val el = ") - deserializeMember(ctx, target.member, this) + deserializeMember(ctx, serdeCtx, target.member, this) } } } private fun deserializeMap( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, member: MemberShape, writer: KotlinWriter, ) { val target = ctx.model.expectShape(member.target) val keySymbol = ctx.symbolProvider.toSymbol(target.key) val valueSymbol = ctx.symbolProvider.toSymbol(target.value) + writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) val isSparse = target.hasTrait() val deserializeFn = deserializeShape(ctx, target) { write("val result = mutableMapOf<#T, #T#L>()", keySymbol, valueSymbol, nullabilitySuffix(isSparse)) - deserializeLoop { + deserializeLoop(SerdeCtx("reader")) { innerCtx -> withBlock("#S -> {", "}", "entry") { val deserializeEntryFn = deserializeMapEntry(ctx, target) - write("#T(result, curr)", deserializeEntryFn) + write("#T(result, ${innerCtx.tagReader})", deserializeEntryFn) } } write("return result") } - writer.write("#T(curr)", deserializeFn) + writer.write("#T(${serdeCtx.tagReader})", deserializeFn) } private fun deserializeFlatMap( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, member: MemberShape, writer: KotlinWriter, ) { @@ -425,6 +477,7 @@ open class XmlParserGenerator( val keySymbol = ctx.symbolProvider.toSymbol(target.key) val valueSymbol = ctx.symbolProvider.toSymbol(target.value) val isSparse = target.hasTrait() + writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) writer.withBlock("run {", "}") { val accum = flatCollectionAccumulatorExpr(ctx, member) write( @@ -435,7 +488,7 @@ open class XmlParserGenerator( nullabilitySuffix(isSparse), ) val deserializeEntryFn = deserializeMapEntry(ctx, target) - write("#T(dest, curr)", deserializeEntryFn) + write("#T(dest, ${serdeCtx.tagReader})", deserializeEntryFn) write("dest") } } @@ -448,6 +501,7 @@ open class XmlParserGenerator( val keySymbol = ctx.symbolProvider.toSymbol(map.key) val valueSymbol = ctx.symbolProvider.toSymbol(map.value) val isSparse = map.hasTrait() + val serdeCtx = SerdeCtx("reader") return buildSymbol { name = "deserialize${shapeName}Entry" @@ -458,8 +512,9 @@ open class XmlParserGenerator( // dedicated map deserializer, they inline the entry deserialization since the map // being built up is not processed all at once writer.withBlock( - "internal fun $name(dest: MutableMap<#T, #T#L>, reader: #T) {", + "internal fun $name(dest: #T<#T, #T#L>, reader: #T) {", "}", + KotlinTypes.Collections.MutableMap, keySymbol, valueSymbol, nullabilitySuffix(isSparse), @@ -467,23 +522,24 @@ open class XmlParserGenerator( ) { write("var key: #T? = null", keySymbol) write("var value: #T? = null", valueSymbol) - deserializeLoop { + writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) + deserializeLoop(serdeCtx) { innerCtx -> val keyName = map.key.getTrait()?.value ?: map.key.memberName writeInline("#S -> key = ", keyName) - deserializeMember(ctx, map.key, this) + deserializeMember(ctx, innerCtx, map.key, this) val valueName = map.value.getTrait()?.value ?: map.value.memberName if (isSparse) { - openBlock("#S -> value = if (curr.nextHasValue()) {", valueName) + openBlock("#S -> value = if (${innerCtx.tagReader}.nextHasValue()) {", valueName) .call { - deserializeMember(ctx, map.value, this) + deserializeMember(ctx, innerCtx, map.value, this) } .closeAndOpenBlock("} else {") .write("null") .closeBlock("}") } else { writeInline("#S -> value = ", valueName) - deserializeMember(ctx, map.value, this) + deserializeMember(ctx, innerCtx, map.value, this) } } write("if (key == null) throw #T(#S)", Serde.DeserializationException, "missing key map entry") @@ -555,4 +611,24 @@ open class XmlParserGenerator( .write(".#T { #S }", Serde.getOrDeserializeErr, escapedErrMessage) .dedent() } + + private fun renderDeserializerUnwrappedXmlBody( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + shape: Shape, + writer: KotlinWriter, + ) { + val members = shape.members() + check(members.size == 1) { + "unwrapped XML output trait is only allowed on operation output structs with exactly one member" + } + + val member = members.first() + writer.withBlock("when(${serdeCtx.tagReader}.startTag.name.tag) {", "}") { + val name = member.getTrait()?.value ?: member.memberName + write("// ${member.memberName} ${escape(member.id.toString())}") + writeInline("#S -> builder.#L = ", name, member.defaultName()) + deserializeMember(ctx, serdeCtx, member, writer) + } + } } diff --git a/gradle.properties b/gradle.properties index 11f5dc5b1..ebe9af9e9 100644 --- a/gradle.properties +++ b/gradle.properties @@ -16,4 +16,4 @@ org.gradle.jvmargs=-Xmx2G -XX:MaxMetaspaceSize=1G sdkVersion=1.0.16-SNAPSHOT # codegen -codegenVersion=0.30.17-SNAPSHOT \ No newline at end of file +codegenVersion=0.30.17-SNAPSHOT diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt index 89df7fee3..9ec628331 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt @@ -9,24 +9,6 @@ import aws.smithy.kotlin.runtime.smithy.test.assertXmlStringsEqual import kotlin.test.assertEquals abstract class AbstractXmlTest { - // // FIXME - remove before merge - this test suite was put in place prior to changing the implementation to - // // verify everything works - // fun testRoundTrip( - // expected: T, - // payload: String, - // serializerFn: (XmlSerializer, T) -> Unit, - // deserializerFn: (XmlDeserializer) -> T, - // ) { - // val serializer = XmlSerializer() - // serializerFn(serializer, expected) - // val actualPayload = serializer.toByteArray().decodeToString() - // assertXmlStringsEqual(payload, actualPayload) - // - // val deserializer = XmlDeserializer(payload.encodeToByteArray()) - // val actualDeserialized = deserializerFn(deserializer) - // assertEquals(expected, actualDeserialized) - // } - fun testRoundTrip( expected: T, payload: String, From 9f71a169176430684fd9635b6e07c7f1ac576bf8 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Fri, 23 Feb 2024 14:00:23 -0500 Subject: [PATCH 11/25] fix attribute lookup --- .../kotlin/runtime/serde/xml/XmlToken.kt | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt index 4b1981a68..c265a9617 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt @@ -30,6 +30,22 @@ public sealed class XmlToken { public data class QualifiedName(public val local: String, public val prefix: String? = null) { override fun toString(): String = tag + @InternalApi + public companion object { + + /** + * Construct a [QualifiedName] from a raw string representation + */ + public fun from(qualified: String): QualifiedName { + val split = qualified.split(":", limit = 2) + val (local, prefix) = when (split.size == 2) { + true -> split[1] to split[0] + false -> split[0] to null + } + return QualifiedName(local, prefix) + } + } + public val tag: String get() = when (prefix) { null -> local else -> "$prefix:$local" @@ -55,7 +71,7 @@ public sealed class XmlToken { override fun toString(): String = "<${this.name} (${this.depth})>" // convenience function for codegen - public fun getAttr(qualified: String): String? = attributes[QualifiedName(qualified)] + public fun getAttr(qualified: String): String? = attributes[QualifiedName.from(qualified)] } /** From 71ebe8655ff00d49d93bdd54525ac10d38b3e1c0 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Fri, 23 Feb 2024 14:33:23 -0500 Subject: [PATCH 12/25] fix serde ctx --- .../kotlin/codegen/rendering/serde/XmlParserGenerator.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index bb5b12791..dac28c06c 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -307,12 +307,12 @@ open class XmlParserGenerator( } ShapeType.STRUCTURE, ShapeType.UNION -> { val deserializeFn = documentDeserializer(ctx, target) - writer.write("#T(curr)", deserializeFn) + writer.write("#T(${serdeCtx.tagReader})", deserializeFn) } else -> deserializePrimitiveMember( ctx, member, - writer.format("curr.#T()", SerdeXml.tryData), + writer.format("${serdeCtx.tagReader}.#T()", SerdeXml.tryData), textExprIsResult = true, writer, ) From 184439a7b1038cf0256d8e983b868d5360aefbe6 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Fri, 23 Feb 2024 17:09:03 -0500 Subject: [PATCH 13/25] drop XmlDeserializer --- .../kotlin/codegen/core/RuntimeTypes.kt | 1 - .../serde/XmlSerdeDescriptorGenerator.kt | 2 - .../serde/XmlSerdeDescriptorGeneratorTest.kt | 45 - .../xml/Ec2QueryErrorDeserializer.kt | 98 +- .../xml/RestXmlErrorDeserializer.kt | 128 +-- .../kotlin/runtime/serde/xml/TagReader.kt | 34 +- .../runtime/serde/xml/XmlDeserializer.kt | 415 -------- .../serde/xml/XmlPrimitiveDeserializer.kt | 76 -- .../runtime/serde/xml/SharedTestData.kt | 358 ------- .../kotlin/runtime/serde/xml/TagReaderTest.kt | 5 +- .../serde/xml/XmlDeserializerAWSTest.kt | 126 --- .../serde/xml/XmlDeserializerListTest.kt | 712 ------------- .../serde/xml/XmlDeserializerMapTest.kt | 795 --------------- .../serde/xml/XmlDeserializerNamespaceTest.kt | 107 -- .../serde/xml/XmlDeserializerPrimitiveTest.kt | 97 -- .../serde/xml/XmlDeserializerStructTest.kt | 404 -------- .../runtime/serde/xml/XmlSerializerTest.kt | 963 ------------------ .../serde/xml/XmlDeserializerBenchmark.kt | 5 +- .../serde/xml/XmlSerializerBenchmark.kt | 5 +- 19 files changed, 95 insertions(+), 4281 deletions(-) delete mode 100644 runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt delete mode 100644 runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/SharedTestData.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerAWSTest.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerListTest.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerMapTest.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerNamespaceTest.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerPrimitiveTest.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerStructTest.kt delete mode 100644 runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlSerializerTest.kt diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index bc2026eda..b21747fee 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -274,7 +274,6 @@ object RuntimeTypes { val XmlMapName = symbol("XmlMapName") val XmlError = symbol("XmlError") val XmlSerializer = symbol("XmlSerializer") - val XmlDeserializer = symbol("XmlDeserializer") val XmlUnwrappedOutput = symbol("XmlUnwrappedOutput") val TagReader = symbol("TagReader") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt index b5bbf4de7..a5f47cf3c 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.core.RenderingContext import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes -import software.amazon.smithy.kotlin.codegen.core.addImport import software.amazon.smithy.kotlin.codegen.core.defaultName import software.amazon.smithy.kotlin.codegen.model.expectShape import software.amazon.smithy.kotlin.codegen.model.expectTrait @@ -78,7 +77,6 @@ open class XmlSerdeDescriptorGenerator( nameSuffix: String, ): List { ctx.writer.addImport( - RuntimeTypes.Serde.SerdeXml.XmlDeserializer, RuntimeTypes.Serde.SerdeXml.XmlSerialName, ) diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt index edb4981aa..3da6ce551 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde -import software.amazon.smithy.kotlin.codegen.core.RUNTIME_ROOT_NS import software.amazon.smithy.kotlin.codegen.test.* import software.amazon.smithy.model.shapes.ShapeId import kotlin.test.Test @@ -141,50 +140,6 @@ class XmlSerdeDescriptorGeneratorTest { contents.shouldContainOnlyOnceWithDiff(expectedDescriptors) } - @Test - fun `it generates expected import declarations`() { - val snippet = """ - @http(method: "POST", uri: "/foo") - operation Foo { - input: FooRequest, - output: FooRequest - } - - @xmlName("CustomFooRequest") - structure FooRequest { - @xmlAttribute - payload: String, - @xmlFlattened - listVal: ListOfString - } - - list ListOfString { - member: String - } - """ - - val expected = """ - import $RUNTIME_ROOT_NS.serde.SdkFieldDescriptor - import $RUNTIME_ROOT_NS.serde.SdkObjectDescriptor - import $RUNTIME_ROOT_NS.serde.SerialKind - import $RUNTIME_ROOT_NS.serde.asSdkSerializable - import $RUNTIME_ROOT_NS.serde.deserializeList - import $RUNTIME_ROOT_NS.serde.deserializeMap - import $RUNTIME_ROOT_NS.serde.deserializeStruct - import $RUNTIME_ROOT_NS.serde.field - import $RUNTIME_ROOT_NS.serde.serializeList - import $RUNTIME_ROOT_NS.serde.serializeMap - import $RUNTIME_ROOT_NS.serde.serializeStruct - import $RUNTIME_ROOT_NS.serde.xml.Flattened - import $RUNTIME_ROOT_NS.serde.xml.XmlAttribute - import $RUNTIME_ROOT_NS.serde.xml.XmlDeserializer - import $RUNTIME_ROOT_NS.serde.xml.XmlSerialName - """.formatForTest("") - - val contents = getContents(snippet, "FooRequest") - contents.shouldContainOnlyOnceWithDiff(expected) - } - @Test fun `it generates field descriptors for flattened xml trait and object descriptor for XmlName trait`() { val snippet = """ diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt index c01e062e5..091d8a90a 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt @@ -6,94 +6,74 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.awsprotocol.ErrorDetails -import aws.smithy.kotlin.runtime.serde.* -import aws.smithy.kotlin.runtime.serde.xml.XmlCollectionName -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer -import aws.smithy.kotlin.runtime.serde.xml.XmlSerialName +import aws.smithy.kotlin.runtime.serde.getOrDeserializeErr +import aws.smithy.kotlin.runtime.serde.xml.* internal data class Ec2QueryErrorResponse(val errors: List, val requestId: String?) internal data class Ec2QueryError(val code: String?, val message: String?) @InternalApi -public suspend fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { - val response = Ec2QueryErrorResponseDeserializer.deserialize(XmlDeserializer(payload, true)) +public fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { + val response = Ec2QueryErrorResponseDeserializer.deserialize(xmlStreamReader(payload).root()) val firstError = response.errors.firstOrNull() return ErrorDetails(firstError?.code, firstError?.message, response.requestId) } /** * Deserializes EC2 Query protocol errors as specified by - * https://awslabs.github.io/smithy/1.0/spec/aws/aws-ec2-query-protocol.html#operation-error-serialization + * https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization */ internal object Ec2QueryErrorResponseDeserializer { - private val ERRORS_DESCRIPTOR = SdkFieldDescriptor( - SerialKind.List, - XmlSerialName("Errors"), - XmlCollectionName("Error"), - ) - private val REQUESTID_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("RequestId")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Response")) - field(ERRORS_DESCRIPTOR) - field(REQUESTID_DESCRIPTOR) - } - - suspend fun deserialize(deserializer: Deserializer): Ec2QueryErrorResponse { - var errors = listOf() + fun deserialize(root: TagReader): Ec2QueryErrorResponse = runCatching { + var errors: List? = null var requestId: String? = null + if (root.startTag.name.tag != "Response") error("expected found ${root.startTag}") - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ERRORS_DESCRIPTOR.index -> errors = deserializer.deserializeList(ERRORS_DESCRIPTOR) { - val collection = mutableListOf() - while (hasNextElement()) { - if (nextHasValue()) { - val element = Ec2QueryErrorDeserializer.deserialize(deserializer) - collection.add(element) - } else { - deserializeNull() - continue - } - } - collection - } - REQUESTID_DESCRIPTOR.index -> requestId = deserializeString() - null -> break@loop - else -> skipValue() - } + loop@while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.startTag.name.tag) { + "Errors" -> errors = Ec2QueryErrorListDeserializer.deserialize(curr) + "RequestId" -> requestId = curr.data() } + curr.drop() } - return Ec2QueryErrorResponse(errors, requestId) + Ec2QueryErrorResponse(errors ?: emptyList(), requestId) + }.getOrDeserializeErr { "Unable to deserialize EC2Query error" } +} + +internal object Ec2QueryErrorListDeserializer { + fun deserialize(root: TagReader): List { + val errors = mutableListOf() + loop@ while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.startTag.name.tag) { + "Error" -> { + val el = Ec2QueryErrorDeserializer.deserialize(curr) + errors.add(el) + } + } + curr.drop() + } + return errors } } internal object Ec2QueryErrorDeserializer { - private val CODE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Code")) - private val MESSAGE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Message")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Error")) - field(CODE_DESCRIPTOR) - field(MESSAGE_DESCRIPTOR) - } - suspend fun deserialize(deserializer: Deserializer): Ec2QueryError { + fun deserialize(root: TagReader): Ec2QueryError { var code: String? = null var message: String? = null - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - CODE_DESCRIPTOR.index -> code = deserializeString() - MESSAGE_DESCRIPTOR.index -> message = deserializeString() - null -> break@loop - else -> skipValue() - } + loop@ while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.startTag.name.tag) { + "Code" -> code = curr.data() + "Message", "message" -> message = curr.data() } + curr.drop() } - return Ec2QueryError(code, message) } } diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt index 6eef5f463..f3cc9e47b 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt @@ -7,8 +7,10 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.awsprotocol.ErrorDetails import aws.smithy.kotlin.runtime.serde.* -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer -import aws.smithy.kotlin.runtime.serde.xml.XmlSerialName +import aws.smithy.kotlin.runtime.serde.xml.TagReader +import aws.smithy.kotlin.runtime.serde.xml.data +import aws.smithy.kotlin.runtime.serde.xml.root +import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader /** * Provides access to specific values regardless of message form @@ -19,16 +21,6 @@ internal interface RestXmlErrorDetails { val message: String? } -// Models "ErrorResponse" type in https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization -internal data class XmlErrorResponse( - val error: XmlError?, - override val requestId: String? = error?.requestId, -) : RestXmlErrorDetails { - override val code: String? = error?.code - override val message: String? = error?.message -} - -// Models "Error" type in https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization internal data class XmlError( override val requestId: String?, override val code: String?, @@ -39,96 +31,56 @@ internal data class XmlError( * Deserializes rest XML protocol errors as specified by: * https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#error-response-serialization * - * Returns parsed data in normalized form or throws IllegalArgumentException if response cannot be parsed. - * NOTE: we use an explicit XML deserializer here because we rely on validating the root element name - * for dealing with the alternate error response forms + * Returns parsed data in normalized form or throws [DeserializationException] if response cannot be parsed. */ @InternalApi -public suspend fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { - val details = ErrorResponseDeserializer.deserialize(XmlDeserializer(payload, true)) - ?: XmlErrorDeserializer.deserialize(XmlDeserializer(payload, true)) - ?: throw DeserializationException("Unable to deserialize RestXml error.") +public fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { + val details = XmlErrorDeserializer.deserialize(xmlStreamReader(payload).root()) return ErrorDetails(details.code, details.message, details.requestId) } -/* - * The deserializers in this file were initially generated by the SDK and then - * adapted to fit this use case of deserializing well-known error structures from - * restXml-based services. - */ - /** - * Deserializes rest Xml protocol errors as specified by: - * - Smithy spec: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization - */ -internal object ErrorResponseDeserializer { - private val ERROR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("Error")) - private val REQUESTID_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("RequestId")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("ErrorResponse")) - field(ERROR_DESCRIPTOR) - field(REQUESTID_DESCRIPTOR) - } - - suspend fun deserialize(deserializer: Deserializer): XmlErrorResponse? { - var requestId: String? = null - var xmlError: XmlError? = null - - return try { - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ERROR_DESCRIPTOR.index -> xmlError = XmlErrorDeserializer.deserialize(deserializer) - REQUESTID_DESCRIPTOR.index -> requestId = deserializeString() - null -> break@loop - else -> skipValue() - } - } - } - - XmlErrorResponse(xmlError, requestId ?: xmlError?.requestId) - } catch (e: DeserializationException) { - null // return so an appropriate exception type can be instantiated above here. - } - } -} - -/** - * This deserializer is used for both the nested Error node from ErrorResponse as well as the top-level - * Error node as described in https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization + * This deserializer is used for both wrapped and unwrapped restXml errors. */ internal object XmlErrorDeserializer { - private val MESSAGE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Message")) - private val CODE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Code")) - private val REQUESTID_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("RequestId")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Error")) - field(MESSAGE_DESCRIPTOR) - field(CODE_DESCRIPTOR) - field(REQUESTID_DESCRIPTOR) - } - - suspend fun deserialize(deserializer: Deserializer): XmlError? { + fun deserialize(root: TagReader): XmlError = runCatching { var message: String? = null var code: String? = null var requestId: String? = null - return try { - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - MESSAGE_DESCRIPTOR.index -> message = deserializeString() - CODE_DESCRIPTOR.index -> code = deserializeString() - REQUESTID_DESCRIPTOR.index -> requestId = deserializeString() - null -> break@loop - else -> skipValue() - } + val rootTagName = root.startTag.name.tag + check(rootTagName == "ErrorResponse" || rootTagName == "Error") { + "expected restXml error response with root tag of or " + } + + // wrapped error, unwrap it + var errTag = root + if (root.startTag.name.tag == "ErrorResponse") { + errTag = root.nextTag() ?: error("expected more tags after ") + } + + if (errTag.startTag.name.tag == "Error") { + loop@ while (true) { + val curr = errTag.nextTag() ?: break@loop + when (curr.startTag.name.tag) { + "Code" -> code = curr.data() + "Message", "message" -> message = curr.data() + "RequestId" -> requestId = curr.data() } + curr.drop() } + } - XmlError(requestId, code, message) - } catch (e: DeserializationException) { - null // return so an appropriate exception type can be instantiated above here. + // wrapped responses + if (requestId == null) { + loop@while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.startTag.name.tag) { + "RequestId" -> requestId = curr.data() + } + } } - } + + XmlError(requestId, code, message) + }.getOrDeserializeErr { "Unable to deserialize RestXml error" } } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt index ad3a30cb0..e0caf1fa3 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt @@ -104,6 +104,9 @@ public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): TagReader { public inline fun TagReader.mapData(transform: (String) -> T): T = transform(data()) +/** + * Unwrap the next token as [XmlToken.Text] and return its' value or throw a [DeserializationException] + */ @InternalApi public fun TagReader.data(): String = when (val next = nextToken()) { @@ -112,32 +115,9 @@ public fun TagReader.data(): String = else -> throw DeserializationException("expected XmlToken.Text element, found $next") } +/** + * Attempt to get the text token as [XmlToken.Text] and return a result containing its' value on success + * or the exception thrown on failure. + */ @InternalApi public fun TagReader.tryData(): Result = runCatching { data() } - -// -// private fun TagReader.mapOrThrow(expected: String, mapper: (String) -> T?): T = -// map { raw -> -// mapper(raw) ?: throw DeserializationException("could not deserialize $raw as $expected for tag ${this.startTag}") -// } -// -// @InternalApi -// public fun TagReader.readInt(): Int = mapOrThrow("Int", String::toIntOrNull) -// -// @InternalApi -// public fun TagReader.readShort(): Short = mapOrThrow("Short", String::toShortOrNull) -// -// @InternalApi -// public fun TagReader.readLong(): Long = mapOrThrow("Long", String::toLongOrNull) -// -// @InternalApi -// public fun TagReader.readFloat(): Float = mapOrThrow("Float", String::toFloatOrNull) -// -// @InternalApi -// public fun TagReader.readDouble(): Double = mapOrThrow("Double", String::toDoubleOrNull) -// -// @InternalApi -// public fun TagReader.readByte(): Byte = mapOrThrow("Byte") { it.toIntOrNull()?.toByte() } -// -// @InternalApi -// public fun TagReader.readBoolean(): Boolean = mapOrThrow("Boolean", String::toBoolean) diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt deleted file mode 100644 index e9b58efe2..000000000 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt +++ /dev/null @@ -1,415 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.InternalApi -import aws.smithy.kotlin.runtime.content.BigDecimal -import aws.smithy.kotlin.runtime.content.BigInteger -import aws.smithy.kotlin.runtime.content.Document -import aws.smithy.kotlin.runtime.serde.* - -private const val FIRST_FIELD_INDEX: Int = 0 - -// Represents aspects of SdkFieldDescriptor that are particular to the Xml format -internal sealed class FieldLocation { - // specifies the mapping to a sdk field index - abstract val fieldIndex: Int - - data class Text(override val fieldIndex: Int) : FieldLocation() // Xml nodes have only one associated Text element - data class Attribute(override val fieldIndex: Int, val names: Set) : FieldLocation() -} - -/** - * Provides a deserializer for XML documents - * - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param validateRootElement Flag indicating if the root XML document [XmlToken.BeginElement] should be validated against - * the descriptor passed to [deserializeStruct]. This only affects the root element, not nested struct elements. Some - * restXml based services DO NOT always send documents with a root element name that matches the shape ID name - * (S3 in particular). This means there is nothing in the model that gives you enough information to validate the tag. - */ -@InternalApi -public class XmlDeserializer( - private val reader: XmlStreamReader, - private val validateRootElement: Boolean = false, -) : Deserializer { - - public constructor(input: ByteArray, validateRootElement: Boolean = false) : this(xmlStreamReader(input), validateRootElement) - - private var firstStructCall = true - - override fun deserializeStruct(descriptor: SdkObjectDescriptor): Deserializer.FieldIterator { - if (firstStructCall) { - if (!descriptor.hasTrait()) throw DeserializationException("Top-level struct $descriptor requires a XmlSerialName trait but has none.") - - firstStructCall = false - - reader.nextToken() // Matching field descriptors to children tags so consume the start element of top-level struct - - val structToken = if (descriptor.hasTrait()) { - reader.seek { it.name == descriptor.expectTrait().errorTag } - } else { - reader.seek() - } ?: throw DeserializationException("Could not find a begin element for new struct") - - if (validateRootElement) { - descriptor.requireNameMatch(structToken.name.tag) - } - } - - // Consume any remaining terminating tokens from previous deserialization - reader.seek() - - // Because attributes set on the root node of the struct, we must read the values before creating the subtree - val attribFields = reader.tokenAttributesToFieldLocations(descriptor) - val parentToken = if (reader.lastToken is XmlToken.BeginElement) { - reader.lastToken as XmlToken.BeginElement - } else { - throw DeserializationException("Expected last parsed token to be ${XmlToken.BeginElement::class} but was ${reader.lastToken}") - } - - val unwrapped = descriptor.hasTrait() - return XmlStructDeserializer(descriptor, reader.subTreeReader(XmlStreamReader.SubtreeStartDepth.CURRENT), parentToken, attribFields, unwrapped) - } - - override fun deserializeList(descriptor: SdkFieldDescriptor): Deserializer.ElementIterator { - val depth = when (descriptor.hasTrait()) { - true -> XmlStreamReader.SubtreeStartDepth.CURRENT - else -> XmlStreamReader.SubtreeStartDepth.CHILD - } - - return XmlListDeserializer(reader.subTreeReader(depth), descriptor) - } - - override fun deserializeMap(descriptor: SdkFieldDescriptor): Deserializer.EntryIterator { - val depth = when (descriptor.hasTrait()) { - true -> XmlStreamReader.SubtreeStartDepth.CURRENT - else -> XmlStreamReader.SubtreeStartDepth.CHILD - } - - return XmlMapDeserializer(reader.subTreeReader(depth), descriptor) - } -} - -/** - * Deserializes specific XML structures into forms that can produce Maps - * - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param descriptor associated [SdkFieldDescriptor] which represents the expected Map - * @param primitiveDeserializer used to deserialize primitive values - */ -internal class XmlMapDeserializer( - private val reader: XmlStreamReader, - private val descriptor: SdkFieldDescriptor, - private val primitiveDeserializer: PrimitiveDeserializer = XmlPrimitiveDeserializer(reader, descriptor), -) : PrimitiveDeserializer by primitiveDeserializer, Deserializer.EntryIterator { - private val mapTrait = descriptor.findTrait() ?: XmlMapName.Default - - override fun hasNextEntry(): Boolean { - val compareTo = when (descriptor.hasTrait()) { - true -> descriptor.findTrait()?.name ?: mapTrait.key // Prefer seeking to XmlSerialName if the trait exists - false -> mapTrait.entry - } - - // Seek to either the XML serial name, entry, or key token depending on the flatness of the map and if the name trait is present - val nextEntryToken = when (descriptor.hasTrait()) { - true -> reader.peekSeek { it.name.local == compareTo } - false -> reader.seek { it.name.local == compareTo } - } - - return nextEntryToken != null - } - - override fun key(): String { - // Seek to the key begin token - reader.seek { it.name.local == mapTrait.key } - ?: error("Unable to find key $mapTrait.key in $descriptor") - - val keyValueToken = reader.takeNextAs() - reader.nextToken() // Consume the end wrapper - - return keyValueToken.value ?: throw DeserializationException("Key unspecified in $descriptor") - } - - override fun nextHasValue(): Boolean { - // Expect a begin and value (or another begin) token if Map entry has a value - val peekBeginToken = reader.peek(1) ?: throw DeserializationException("Unexpected termination of token stream in $descriptor") - val peekValueToken = reader.peek(2) ?: throw DeserializationException("Unexpected termination of token stream in $descriptor") - - return peekBeginToken !is XmlToken.EndElement && peekValueToken !is XmlToken.EndElement - } -} - -/** - * Deserializes specific XML structures into forms that can produce Lists - * - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param descriptor associated [SdkFieldDescriptor] which represents the expected Map - * @param primitiveDeserializer used to deserialize primitive values - */ -internal class XmlListDeserializer( - private val reader: XmlStreamReader, - private val descriptor: SdkFieldDescriptor, - private val primitiveDeserializer: PrimitiveDeserializer = XmlPrimitiveDeserializer(reader, descriptor), -) : PrimitiveDeserializer by primitiveDeserializer, Deserializer.ElementIterator { - private var firstCall = true - private val flattened = descriptor.hasTrait() - private val elementName = (descriptor.findTrait() ?: XmlCollectionName.Default).element - - override fun hasNextElement(): Boolean { - if (!flattened && firstCall) { - val nextToken = reader.peek() - val matchedListDescriptor = nextToken is XmlToken.BeginElement && descriptor.nameMatches(nextToken.name.tag) - val hasChildren = if (nextToken == null) false else nextToken.depth >= reader.lastToken!!.depth - - if (!matchedListDescriptor && !hasChildren) return false - - // Discard the wrapper and move to the first element in the list - if (matchedListDescriptor) reader.nextToken() - - firstCall = false - } - - if (flattened) { - // Because our subtree is not CHILD, we cannot rely on the subtree boundary to determine end of collection. - // Rather, we search for either the next begin token matching the (flat) list member name which should - // be immediately after the current token - - // peek at the next token if there is one, in the case of a list of structs, the next token is actually - // the end of the current flat list element in which case we need to peek twice - val next = when (val peeked = reader.peek()) { - is XmlToken.EndElement -> { - if (peeked.name.local == descriptor.serialName.name) { - // consume the end token - reader.nextToken() - reader.peek() - } else { - peeked - } - } - else -> peeked - } - - val tokens = listOfNotNull(reader.lastToken, next) - - // Iterate over the token stream until begin token matching name is found or end element matching list is found. - return tokens - .filterIsInstance() - .any { it.name.local == descriptor.serialName.name } - } else { - // If we can find another begin token w/ the element name, we have more elements to process - return reader.seek { it.name.local == elementName }.isNotTerminal() - } - } - - override fun nextHasValue(): Boolean = reader.peek() !is XmlToken.EndElement -} - -/** - * Deserializes specific XML structures into forms that can produce structures - * - * @param objDescriptor associated [SdkObjectDescriptor] which represents the expected structure - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param parentToken initial token of associated structure - * @param parsedFieldLocations list of [FieldLocation] representing values able to be loaded into deserialized instances - */ -private class XmlStructDeserializer( - private val objDescriptor: SdkObjectDescriptor, - reader: XmlStreamReader, - private val parentToken: XmlToken.BeginElement, - private val parsedFieldLocations: MutableList = mutableListOf(), - private val unwrapped: Boolean, -) : Deserializer.FieldIterator { - // Used to track direct deserialization or further nesting between calls to findNextFieldIndex() and deserialize() - private var reentryFlag: Boolean = false - - private val reader: XmlStreamReader = if (unwrapped) reader else reader.subTreeReader(XmlStreamReader.SubtreeStartDepth.CHILD) - - override fun findNextFieldIndex(): Int? { - if (unwrapped) { - return if (reader.peek() is XmlToken.Text) FIRST_FIELD_INDEX else null - } - if (inNestedMode()) { - // Returning from a nested struct call. Nested deserializer consumed - // tokens so clear them here to avoid processing stale state - parsedFieldLocations.clear() - } - - if (parsedFieldLocations.isEmpty()) { - val matchedFieldLocations = when (val token = reader.nextToken()) { - null, is XmlToken.EndDocument -> return null - is XmlToken.EndElement -> return findNextFieldIndex() - is XmlToken.BeginElement -> { - val nextToken = reader.peek() ?: return null - val objectFields = objDescriptor.fields - val memberFields = objectFields.filter { field -> objDescriptor.fieldTokenMatcher(field, token) } - val matchingFields = memberFields.mapNotNull { it.findFieldLocation(token, nextToken) } - matchingFields - } - else -> return findNextFieldIndex() - } - - // Sorting ensures attribs are processed before text, as processing the Text token pushes the parser on to the next token. - parsedFieldLocations.addAll(matchedFieldLocations.sortedBy { it is FieldLocation.Text }) - } - - return parsedFieldLocations.firstOrNull()?.fieldIndex ?: Deserializer.FieldIterator.UNKNOWN_FIELD - } - - private fun deserializeValue(transform: ((String) -> T)): T { - if (unwrapped) { - val value = reader.takeNextAs().value ?: "" - return transform(value) - } - // Set and validate mode - reentryFlag = false - if (parsedFieldLocations.isEmpty()) throw DeserializationException("matchedFields is empty, was findNextFieldIndex() called?") - - // Take the first FieldLocation and attempt to parse it into the value specified by the descriptor. - return when (val nextField = parsedFieldLocations.removeFirst()) { - is FieldLocation.Text -> { - val value = when (val peekToken = reader.peek()) { - is XmlToken.Text -> reader.takeNextAs().value ?: "" - is XmlToken.EndElement -> "" - else -> throw DeserializationException("Unexpected token $peekToken") - } - transform(value) - } - is FieldLocation.Attribute -> { - transform( - nextField - .names - .mapNotNull { parentToken.attributes[it] } - .firstOrNull() ?: throw DeserializationException("Expected attrib value ${nextField.names.first()} not found in ${parentToken.name}"), - ) - } - } - } - - override fun skipValue() = reader.skipNext() - - override fun deserializeByte(): Byte = deserializeValue { it.toIntOrNull()?.toByte() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeInt(): Int = deserializeValue { it.toIntOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeShort(): Short = deserializeValue { it.toIntOrNull()?.toShort() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeLong(): Long = deserializeValue { it.toLongOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeFloat(): Float = deserializeValue { it.toFloatOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeDouble(): Double = deserializeValue { it.toDoubleOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeBigInteger(): BigInteger = deserializeValue { - runCatching { BigInteger(it) } - .getOrElse { throw DeserializationException("Unable to deserialize $it as BigInteger") } - } - - override fun deserializeBigDecimal(): BigDecimal = deserializeValue { - runCatching { BigDecimal(it) } - .getOrElse { throw DeserializationException("Unable to deserialize $it as BigDecimal") } - } - - override fun deserializeString(): String = deserializeValue { it } - - override fun deserializeBoolean(): Boolean = deserializeValue { it.toBoolean() } - - override fun deserializeDocument(): Document { - throw DeserializationException("cannot deserialize unsupported Document type in xml") - } - - override fun deserializeNull(): Nothing? { - reader.takeNextAs() - return null - } - - // A struct deserializer can be called in two "modes": - // 1. to deserialize a value. This calls findNextFieldIndex() followed by deserialize() - // 2. to deserialize a nested container. This calls findNextFieldIndex() followed by a call to another deserialize() - // Because state is built in findNextFieldIndex() that is intended to be used directly in deserialize() (mode 1) - // and there is no explicit way that this type knows which mode is in use, the state built must be cleared. - // this is done by flipping a bit between the two calls. If the bit has not been flipped on any call to findNextFieldIndex() - // it is determined that the nested mode was used and any existing state should be cleared. - // if the state is not cleared, deserialization goes into an infinite loop because the deserializer sees pending fields to pull from the stream - // which are never consumed by the (missing) call to deserialize() - private fun inNestedMode(): Boolean = when (reentryFlag) { - true -> true - false -> { reentryFlag = true; false } - } -} - -// Extract the attributes from the last-read token and match them to [FieldLocation] on the [SdkObjectDescriptor]. -private fun XmlStreamReader.tokenAttributesToFieldLocations(descriptor: SdkObjectDescriptor): MutableList = - if (descriptor.hasXmlAttributes && lastToken is XmlToken.BeginElement) { - val attribFields = descriptor.fields.filter { it.hasTrait() } - val matchedAttribFields = attribFields.filter { it.findFieldLocation(lastToken as XmlToken.BeginElement, peek() ?: throw DeserializationException("Unexpected end of tokens")) != null } - matchedAttribFields.map { FieldLocation.Attribute(it.index, it.toQualifiedNames()) } - .toMutableList() - } else { - mutableListOf() - } - -// Returns a [FieldLocation] if the field maps to the current token -private fun SdkFieldDescriptor.findFieldLocation( - currentToken: XmlToken.BeginElement, - nextToken: XmlToken, -): FieldLocation? = when (val property = toFieldLocation()) { - is FieldLocation.Text -> { - when { - nextToken is XmlToken.Text -> property - nextToken is XmlToken.BeginElement -> property - // The following allows for struct primitives to remain unvisited if no value - // but causes nested deserializers to be called even if they contain no value - nextToken is XmlToken.EndElement && currentToken.name == nextToken.name -> property - else -> null - } - } - is FieldLocation.Attribute -> { - val foundMatch = property.names.any { currentToken.attributes[it]?.isNotBlank() == true } - if (foundMatch) property else null - } -} - -// Produce a [FieldLocation] type based on presence of traits of field -// A field without an attribute trait is assumed to be a text token -private fun SdkFieldDescriptor.toFieldLocation(): FieldLocation = - when (findTrait()) { - null -> FieldLocation.Text(index) // Assume a text value if no attributes defined. - else -> FieldLocation.Attribute(index, toQualifiedNames()) - } - -// Matches fields and tokens with matching qualified name -private fun SdkObjectDescriptor.fieldTokenMatcher(fieldDescriptor: SdkFieldDescriptor, beginElement: XmlToken.BeginElement): Boolean { - if (fieldDescriptor.kind == SerialKind.List && fieldDescriptor.hasTrait()) { - val fieldName = fieldDescriptor.findTrait() ?: XmlCollectionName.Default - val tokenQname = beginElement.name - - // It may be that we are matching a flattened list element or matching a list itself. In the latter - // case the following predicate will not work, so if we fail to match the member - // try again (below) to match against the container. - if (fieldName.element == tokenQname.local) return true - } - - return fieldDescriptor.nameMatches(beginElement.name.tag) -} - -/** - * Return the next token of the specified type or throw [DeserializationException] if incorrect type. - */ -internal inline fun XmlStreamReader.takeNextAs(): TExpected { - val token = this.nextToken() ?: throw DeserializationException("Expected ${TExpected::class} but instead found null") - requireToken(token) - return token as TExpected -} - -/** - * Require that the given token be of type [TExpected] or else throw an exception - */ -internal inline fun requireToken(token: XmlToken) { - if (token::class != TExpected::class) { - throw DeserializationException("Expected ${TExpected::class}; found ${token::class} ($token)") - } -} diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt deleted file mode 100644 index 1707fdf32..000000000 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.content.BigDecimal -import aws.smithy.kotlin.runtime.content.BigInteger -import aws.smithy.kotlin.runtime.content.Document -import aws.smithy.kotlin.runtime.serde.* - -/** - * Deserialize primitive values for single values, lists, and maps - */ -internal class XmlPrimitiveDeserializer(private val reader: XmlStreamReader, private val fieldDescriptor: SdkFieldDescriptor) : - PrimitiveDeserializer { - - constructor(input: ByteArray, fieldDescriptor: SdkFieldDescriptor) : this(xmlStreamReader(input), fieldDescriptor) - - private fun deserializeValue(transform: ((String) -> T)): T { - if (reader.peek() is XmlToken.BeginElement) { - // In the case of flattened lists, we "fall" into the first member as there is no wrapper. - // this conditional checks that case for the first element of the list. - val wrapperToken = reader.takeNextAs() - if (wrapperToken.name.local != fieldDescriptor.generalName()) { - // Depending on flat/not-flat, may need to consume multiple start tokens - return deserializeValue(transform) - } - } - - val token = reader.takeNextAs() - - return token.value - ?.let { transform(it) } - ?.also { reader.takeNextAs() } ?: throw DeserializationException("$token specifies nonexistent or invalid value.") - } - - override fun deserializeByte(): Byte = deserializeValue { it.toIntOrNull()?.toByte() ?: throw DeserializationException("Unable to deserialize $it as Byte") } - - override fun deserializeInt(): Int = deserializeValue { it.toIntOrNull() ?: throw DeserializationException("Unable to deserialize $it as Int") } - - override fun deserializeShort(): Short = deserializeValue { it.toIntOrNull()?.toShort() ?: throw DeserializationException("Unable to deserialize $it as Short") } - - override fun deserializeLong(): Long = deserializeValue { it.toLongOrNull() ?: throw DeserializationException("Unable to deserialize $it as Long") } - - override fun deserializeFloat(): Float = deserializeValue { it.toFloatOrNull() ?: throw DeserializationException("Unable to deserialize $it as Float") } - - override fun deserializeDouble(): Double = deserializeValue { it.toDoubleOrNull() ?: throw DeserializationException("Unable to deserialize $it as Double") } - - override fun deserializeBigInteger(): BigInteger = deserializeValue { - runCatching { BigInteger(it) } - .getOrElse { throw DeserializationException("Unable to deserialize $it as BigInteger") } - } - - override fun deserializeBigDecimal(): BigDecimal = deserializeValue { - runCatching { BigDecimal(it) } - .getOrElse { throw DeserializationException("Unable to deserialize $it as BigDecimal") } - } - - override fun deserializeString(): String = deserializeValue { it } - - override fun deserializeBoolean(): Boolean = deserializeValue { it.toBoolean() } - - override fun deserializeDocument(): Document { - throw DeserializationException("cannot deserialize unsupported Document type in xml") - } - - override fun deserializeNull(): Nothing? { - reader.nextToken() ?: throw DeserializationException("Unexpected end of stream") - reader.seek() - reader.nextToken() ?: throw DeserializationException("Unexpected end of stream") - - return null - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/SharedTestData.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/SharedTestData.kt deleted file mode 100644 index 94a905df3..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/SharedTestData.kt +++ /dev/null @@ -1,358 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* - -class SimpleStructClass { - var x: Int? = null - var y: Int? = null - var z: String? = null - - // Only for testing, not serialization - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("y")) - val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("z"), XmlAttribute) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - field(Z_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): SimpleStructClass { - val result = SimpleStructClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeInt() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - Z_DESCRIPTOR.index -> result.z = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class SimpleStructOfStringsClass { - var x: String? = null - var y: String? = null - - // Only for testing, not serialization - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("y")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): SimpleStructOfStringsClass { - val result = SimpleStructOfStringsClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeString() - Y_DESCRIPTOR.index -> result.y = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class StructWithAttribsClass { - var x: Int? = null - var y: Int? = null - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x"), XmlAttribute) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("y"), XmlAttribute) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): StructWithAttribsClass { - val result = StructWithAttribsClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeInt() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - result.unknownFieldCount++ - skipValue() - } - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class StructWithMultiAttribsAndTextValClass { - var x: Int? = null - var y: Int? = null - var txt: String? = null - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("xval"), XmlAttribute) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("yval"), XmlAttribute) - val TXT_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(TXT_DESCRIPTOR) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): StructWithMultiAttribsAndTextValClass { - val result = StructWithMultiAttribsAndTextValClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeInt() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - TXT_DESCRIPTOR.index -> result.txt = deserializeString() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - result.unknownFieldCount++ - skipValue() - } - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class RecursiveShapesInputOutput private constructor(builder: Builder) { - val nested: RecursiveShapesInputOutputNested1? = builder.nested - - companion object { - operator fun invoke(block: Builder.() -> kotlin.Unit): RecursiveShapesInputOutput = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("RecursiveShapesInputOutput(") - append("nested=$nested)") - } - - override fun hashCode(): kotlin.Int { - var result = nested?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as RecursiveShapesInputOutput - - if (nested != other.nested) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): RecursiveShapesInputOutput = Builder(this).apply(block).build() - - public class Builder() { - var nested: RecursiveShapesInputOutputNested1? = null - - constructor(x: RecursiveShapesInputOutput) : this() { - this.nested = x.nested - } - - fun build(): RecursiveShapesInputOutput = RecursiveShapesInputOutput(this) - } -} - -class RecursiveShapesInputOutputNested1 private constructor(builder: Builder) { - val foo: String? = builder.foo - val nested: RecursiveShapesInputOutputNested2? = builder.nested - - companion object { - fun dslBuilder(): Builder = Builder() - - operator fun invoke(block: Builder.() -> kotlin.Unit): RecursiveShapesInputOutputNested1 = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("RecursiveShapesInputOutputNested1(") - append("foo=$foo,") - append("nested=$nested)") - } - - override fun hashCode(): kotlin.Int { - var result = foo?.hashCode() ?: 0 - result = 31 * result + (nested?.hashCode() ?: 0) - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as RecursiveShapesInputOutputNested1 - - if (foo != other.foo) return false - if (nested != other.nested) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): RecursiveShapesInputOutputNested1 = Builder(this).apply(block).build() - - public class Builder() { - var foo: String? = null - var nested: RecursiveShapesInputOutputNested2? = null - - constructor(x: RecursiveShapesInputOutputNested1) : this() { - this.foo = x.foo - this.nested = x.nested - } - - fun build(): RecursiveShapesInputOutputNested1 = RecursiveShapesInputOutputNested1(this) - } -} - -class RecursiveShapesInputOutputNested2 private constructor(builder: Builder) { - val bar: String? = builder.bar - val recursiveMember: RecursiveShapesInputOutputNested1? = builder.recursiveMember - - companion object { - fun dslBuilder(): Builder = Builder() - - operator fun invoke(block: Builder.() -> kotlin.Unit): RecursiveShapesInputOutputNested2 = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("RecursiveShapesInputOutputNested2(") - append("bar=$bar,") - append("recursiveMember=$recursiveMember)") - } - - override fun hashCode(): kotlin.Int { - var result = bar?.hashCode() ?: 0 - result = 31 * result + (recursiveMember?.hashCode() ?: 0) - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as RecursiveShapesInputOutputNested2 - - if (bar != other.bar) return false - if (recursiveMember != other.recursiveMember) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): RecursiveShapesInputOutputNested2 = Builder(this).apply(block).build() - - public class Builder() { - var bar: String? = null - var recursiveMember: RecursiveShapesInputOutputNested1? = null - - constructor(x: RecursiveShapesInputOutputNested2) : this() { - this.bar = x.bar - this.recursiveMember = x.recursiveMember - } - - fun build(): RecursiveShapesInputOutputNested2 = RecursiveShapesInputOutputNested2(this) - } -} - -/* - @xmlNamespace(uri: "http://foo.com") - structure XmlNamespacesInputOutput { - nested: XmlNamespaceNested - } - - // Ignored since it's not at the top-level - @xmlNamespace(uri: "http://foo.com") - structure XmlNamespaceNested { - @xmlNamespace(uri: "http://baz.com", prefix: "baz") - foo: String, - - @xmlNamespace(uri: "http://qux.com") - values: XmlNamespacedList - } - - list XmlNamespacedList { - @xmlNamespace(uri: "http://bux.com") - member: String, - } -*/ -class XmlNamespacesRequest(val nested: XmlNamespaceNested?) { - companion object { - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("XmlNamespacesInputOutput")) - trait(XmlNamespace("http://foo.com")) - field(NESTED_DESCRIPTOR) - } - } - - fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - nested?.let { field(NESTED_DESCRIPTOR, XmlNamespaceNestedDocumentSerializer(it)) } - } - } -} - -data class XmlNamespaceNested( - val foo: String? = null, - val values: List? = null, -) - -internal class XmlNamespaceNestedDocumentSerializer(val input: XmlNamespaceNested) : SdkSerializable { - - companion object { - private val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("foo"), XmlNamespace("http://baz.com", "baz")) - private val VALUES_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("values"), XmlNamespace("http://qux.com"), XmlCollectionValueNamespace("http://bux.com")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("XmlNamespaceNested")) - trait(XmlNamespace("http://foo.com")) - field(FOO_DESCRIPTOR) - field(VALUES_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.foo?.let { field(FOO_DESCRIPTOR, it) } - if (input.values != null) { - listField(VALUES_DESCRIPTOR) { - for (el0 in input.values) { - serializeString(el0) - } - } - } - } - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt index 411ec6a66..ca22315be 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt @@ -4,6 +4,7 @@ */ package aws.smithy.kotlin.runtime.serde.xml +import aws.smithy.kotlin.runtime.serde.parseInt import kotlin.test.* class TagReaderTest { @@ -119,8 +120,8 @@ class TagReaderTest { val curr = decoder.nextTag() ?: break@loop when (curr.startTag.name.tag) { "Child1" -> { - assertEquals(1, curr.nextTag()?.readInt()) - assertEquals(2, curr.nextTag()?.readInt()) + assertEquals(1, curr.nextTag()?.data()?.parseInt()?.getOrNull()) + assertEquals(2, curr.nextTag()?.data()?.parseInt()?.getOrNull()) } "Child2" -> { assertEquals("this is an a", curr.nextTag()?.data()) diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerAWSTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerAWSTest.kt deleted file mode 100644 index bf7b1ee36..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerAWSTest.kt +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertNotNull -import kotlin.test.assertTrue - -class XmlDeserializerAWSTest { - - class HostedZoneConfig private constructor(builder: Builder) { - val comment: String? = builder.comment - - companion object { - val COMMENT_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Comment")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("HostedZoneConfig")) - trait(XmlNamespace("https://route53.amazonaws.com/doc/2013-04-01/")) - field(COMMENT_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): HostedZoneConfig { - val builder = Builder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - COMMENT_DESCRIPTOR.index -> builder.comment = deserializeString() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - } - else -> throw DeserializationException(IllegalStateException("unexpected field index in HostedZoneConfig deserializer")) - } - } - } - return HostedZoneConfig(builder) - } - - operator fun invoke(block: Builder.() -> Unit) = Builder().apply(block).build() - } - - public class Builder { - var comment: String? = null - - fun build(): HostedZoneConfig = HostedZoneConfig(this) - } - } - - class CreateHostedZoneRequest private constructor(builder: Builder) { - val name: String? = builder.name - val callerReference: String? = builder.callerReference - val hostedZoneConfig: HostedZoneConfig? = builder.hostedZoneConfig - - companion object { - val NAME_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Name")) - val CALLER_REFERENCE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("CallerReference")) - val HOSTED_ZONE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("HostedZoneConfig")) - - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("CreateHostedZoneRequest")) - trait(XmlNamespace("https://route53.amazonaws.com/doc/2013-04-01/")) - field(NAME_DESCRIPTOR) - field(CALLER_REFERENCE_DESCRIPTOR) - field(HOSTED_ZONE_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): CreateHostedZoneRequest { - val builder = Builder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - NAME_DESCRIPTOR.index -> builder.name = deserializeString() - CALLER_REFERENCE_DESCRIPTOR.index -> builder.callerReference = deserializeString() - HOSTED_ZONE_DESCRIPTOR.index -> - builder.hostedZoneConfig = HostedZoneConfig.deserialize(deserializer) - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> skipValue() - else -> throw DeserializationException(IllegalStateException("unexpected field index in CreateHostedZoneRequest deserializer")) - } - } - } - return builder.build() - } - - operator fun invoke(block: Builder.() -> Unit) = Builder().apply(block).build() - } - - public class Builder { - var name: String? = null - var callerReference: String? = null - var hostedZoneConfig: HostedZoneConfig? = null - - fun build(): CreateHostedZoneRequest = CreateHostedZoneRequest(this) - } - } - - @Test - fun itHandlesRoute53XML() { - val testXml = """ - - - - java.sdk.com. - a322f752-8156-4746-8c04-e174ca1f51ce - - comment - - - """.trimIndent() - - val unit = XmlDeserializer(testXml.encodeToByteArray()) - - val createHostedZoneRequest = CreateHostedZoneRequest.deserialize(unit) - - assertTrue(createHostedZoneRequest.name == "java.sdk.com.") - assertTrue(createHostedZoneRequest.callerReference == "a322f752-8156-4746-8c04-e174ca1f51ce") - assertNotNull(createHostedZoneRequest.hostedZoneConfig) - assertTrue(createHostedZoneRequest.hostedZoneConfig.comment == "comment") - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerListTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerListTest.kt deleted file mode 100644 index 2ba7052e6..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerListTest.kt +++ /dev/null @@ -1,712 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import io.kotest.matchers.collections.shouldContainExactly -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue - -class XmlDeserializerListTest { - - class ListDeserializer private constructor(builder: BuilderImpl) { - val list: List? = builder.list - - companion object { - operator fun invoke(block: DslBuilder.() -> Unit) = BuilderImpl().apply(block).build() - fun dslBuilder(): DslBuilder = BuilderImpl() - - fun deserialize( - deserializer: Deserializer, - OBJ_DESCRIPTOR: SdkObjectDescriptor, - ELEMENT_LIST_FIELD_DESCRIPTOR: SdkFieldDescriptor, - ): ListDeserializer { - val builder = dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ELEMENT_LIST_FIELD_DESCRIPTOR.index -> - builder.list = - deserializer.deserializeList(ELEMENT_LIST_FIELD_DESCRIPTOR) { - val list = mutableListOf() - while (hasNextElement()) { - list.add(deserializeInt()) - } - return@deserializeList list - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } - } - - interface Builder { - fun build(): ListDeserializer - // TODO - Java fill in Java builder - } - - interface DslBuilder { - var list: List? - - fun build(): ListDeserializer - } - - private class BuilderImpl : Builder, DslBuilder { - override var list: List? = null - - override fun build(): ListDeserializer = ListDeserializer(this) - } - } - - @Test - fun itHandlesListSingleElement() { - val payload = """ - - - 1 - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = listOf(1) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesListMultipleElementsAndCustomMemberName() { - val payload = """ - - - 1 - 2 - 3 - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), XmlCollectionName("element")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = listOf(1, 2, 3) - - actual.shouldContainExactly(expected) - } - - class SparseListDeserializer private constructor(builder: BuilderImpl) { - val list: List? = builder.list - - companion object { - operator fun invoke(block: DslBuilder.() -> Unit) = BuilderImpl().apply(block).build() - fun dslBuilder(): DslBuilder = BuilderImpl() - - fun deserialize( - deserializer: Deserializer, - OBJ_DESCRIPTOR: SdkObjectDescriptor, - ELEMENT_LIST_FIELD_DESCRIPTOR: SdkFieldDescriptor, - ): SparseListDeserializer { - val builder = dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ELEMENT_LIST_FIELD_DESCRIPTOR.index -> - builder.list = - deserializer.deserializeList(ELEMENT_LIST_FIELD_DESCRIPTOR) { - val col0 = mutableListOf() - while (hasNextElement()) { - val el0 = if (nextHasValue()) { - deserializeInt() - } else { - deserializeNull() - } - col0.add(el0) - } - col0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } - } - - interface Builder { - fun build(): SparseListDeserializer - // TODO - Java fill in Java builder - } - - interface DslBuilder { - var list: List? - - fun build(): SparseListDeserializer - } - - private class BuilderImpl : Builder, DslBuilder { - override var list: List? = null - - override fun build(): SparseListDeserializer = SparseListDeserializer(this) - } - } - - @Test - fun itHandlesSparseLists() { - val payload = """ - - - 1 - - 3 - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), SparseValues) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = - SparseListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = listOf(1, null, 3) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesEmptyLists() { - val payload = """ - - - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = emptyList() - - assertEquals(expected, actual) - } - - @Test - fun itHandlesFlatLists() { - val payload = """ - - 1 - 2 - 3 - - """.encodeToByteArray() - val elementFieldDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("element"), Flattened) - val objectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(elementFieldDescriptor) - } - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, objectDescriptor, elementFieldDescriptor).list - val expected = listOf(1, 2, 3) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesListOfObjectsWithMissingFields() { - val payload = """ - - - - a - b - - - - d - - - - """.encodeToByteArray() - val listWrapperFieldDescriptor = - SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), XmlCollectionName(element = "payload")) - val objectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(listWrapperFieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual: MutableList? = null - deserializer.deserializeStruct(objectDescriptor) { - loop@ while (true) { - when (findNextFieldIndex()) { - listWrapperFieldDescriptor.index -> - actual = - deserializer.deserializeList(listWrapperFieldDescriptor) { - val list = mutableListOf() - while (hasNextElement()) { - list.add(SimpleStructOfStringsClass.deserialize(deserializer)) - } - return@deserializeList list - } - null -> break@loop - else -> skipValue() - } - } - } - - assertEquals(2, actual!!.size) - - assertEquals("a", actual!![0].x) - assertEquals("b", actual!![0].y) - assertEquals("", actual!![1].x) - assertEquals("d", actual!![1].y) - } - - @Test - fun itHandlesListOfObjectsWithEmptyValues() { - val payload = """ - - - - 1 - 2 - - - - - """.encodeToByteArray() - val listWrapperFieldDescriptor = - SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), XmlCollectionName(element = "payload")) - val objectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(listWrapperFieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual: MutableList? = null - deserializer.deserializeStruct(objectDescriptor) { - loop@ while (true) { - when (findNextFieldIndex()) { - listWrapperFieldDescriptor.index -> - actual = - deserializer.deserializeList(listWrapperFieldDescriptor) { - val list = mutableListOf() - while (hasNextElement()) { - list.add(SimpleStructClass.deserialize(deserializer)) - } - return@deserializeList list - } - null -> break@loop - else -> skipValue() - } - } - } - assertEquals(2, actual!!.size) - assertEquals(1, actual!![0].x) - assertEquals(2, actual!![0].y) - assertEquals(null, actual!![1].x) - assertEquals(null, actual!![1].y) - } - - @Test - fun itHandlesNestedLists() { - val payload = """ - - - - - a - 3 - - - a - 3 - - - - - b - 4 - - - c - 5 - - - - - d - 8 - - - e - 9 - - - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val actual = NestedListOperationOperationDeserializer().deserialize(deserializer) - - assertTrue(actual.parentList?.size == 3) - } - - @Test - fun itHandlesListsOfStructs() { - val payload = """ - - - - a - 3 - - - b - 4 - - - c - 6 - - - - """.encodeToByteArray() - - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("parentList")) - val deserializer = XmlDeserializer(payload) - val actual = FooOperationDeserializer().deserialize(deserializer, listDescriptor) - - assertTrue(actual.parentList?.size == 3) - } - - @Test - fun itHandlesFlatListsOfStructs() { - val payload = """ - - - a - 3 - - - b - 4 - - - c - 6 - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("flatList"), Flattened) - val actual = FooOperationDeserializer().deserialize(deserializer, listDescriptor) - - val parentList = assertNotNull(actual.parentList) - assertEquals(3, parentList.size) - assertEquals(parentList[0].fooMember, "a") - assertEquals(parentList[0].someInt, 3) - assertEquals(parentList[2].fooMember, "c") - assertEquals(parentList[2].someInt, 6) - } -} - -internal class FooOperationDeserializer { - - fun deserialize( - deserializer: Deserializer, - LIST_DESCRIPTOR: SdkFieldDescriptor, - ): FooResponse { - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("FooResponse")) - field(LIST_DESCRIPTOR) - } - - val builder = FooResponse.Builder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - LIST_DESCRIPTOR.index -> - builder.parentList = - deserializer.deserializeList(LIST_DESCRIPTOR) { - val col0 = mutableListOf() - while (hasNextElement()) { - val el0 = if (nextHasValue()) { PayloadStructDocumentDeserializer().deserialize(deserializer) } else { deserializeNull(); continue } - col0.add(el0) - } - col0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} - -internal class PayloadStructDocumentDeserializer { - - companion object { - private val FOOMEMBER_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("fooMember")) - private val SOMEINT_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("someInt")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - field(FOOMEMBER_DESCRIPTOR) - field(SOMEINT_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): PayloadStruct { - val builder = PayloadStruct.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - FOOMEMBER_DESCRIPTOR.index -> builder.fooMember = deserializeString() - SOMEINT_DESCRIPTOR.index -> builder.someInt = deserializeInt() - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} - -class FooResponse private constructor(builder: Builder) { - val parentList: List? = builder.parentList - - companion object { - fun builder(): Builder = Builder() - - operator fun invoke(block: Builder.() -> kotlin.Unit): FooResponse = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("FooResponse(") - append("parentList=$parentList)") - } - - override fun hashCode(): kotlin.Int { - var result = parentList?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as FooResponse - - if (parentList != other.parentList) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): FooResponse = Builder(this).apply(block).build() - - public class Builder() { - var parentList: List? = null - - constructor(x: FooResponse) : this() { - this.parentList = x.parentList - } - - fun build(): FooResponse = FooResponse(this) - } -} - -class PayloadStruct private constructor(builder: BuilderImpl) { - val fooMember: String? = builder.fooMember - val someInt: Int? = builder.someInt - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): PayloadStruct = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("PayloadStruct(") - append("fooMember=$fooMember,") - append("someInt=$someInt)") - } - - override fun hashCode(): kotlin.Int { - var result = fooMember?.hashCode() ?: 0 - result = 31 * result + (someInt ?: 0) - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as PayloadStruct - - if (fooMember != other.fooMember) return false - if (someInt != other.someInt) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): PayloadStruct = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): PayloadStruct - fun fooMember(fooMember: String): Builder - fun someInt(someInt: Int): Builder - } - - interface DslBuilder { - var fooMember: String? - var someInt: Int? - - fun build(): PayloadStruct - } - - private class BuilderImpl() : Builder, DslBuilder { - override var fooMember: String? = null - override var someInt: Int? = null - - constructor(x: PayloadStruct) : this() { - this.fooMember = x.fooMember - this.someInt = x.someInt - } - - override fun build(): PayloadStruct = PayloadStruct(this) - override fun fooMember(fooMember: String): Builder = apply { this.fooMember = fooMember } - override fun someInt(someInt: Int): Builder = apply { this.someInt = someInt } - } -} - -class NestedListResponse private constructor(builder: BuilderImpl) { - val parentList: List>? = builder.parentList - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): NestedListResponse = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("NestedListResponse(") - append("parentList=$parentList)") - } - - override fun hashCode(): kotlin.Int { - var result = parentList?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as NestedListResponse - - if (parentList != other.parentList) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): NestedListResponse = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): NestedListResponse - fun parentList(parentList: List>): Builder - } - - interface DslBuilder { - var parentList: List>? - - fun build(): NestedListResponse - } - - private class BuilderImpl() : Builder, DslBuilder { - override var parentList: List>? = null - - constructor(x: NestedListResponse) : this() { - this.parentList = x.parentList - } - - override fun build(): NestedListResponse = NestedListResponse(this) - override fun parentList(parentList: List>): Builder = apply { this.parentList = parentList } - } -} - -internal class NestedListOperationOperationDeserializer { - - companion object { - private val PARENTLIST_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("parentList")) - private val PARENTLIST_C0_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("parentListC0")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("NestedListResponse")) - field(PARENTLIST_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): NestedListResponse { - val builder = NestedListResponse.dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - PARENTLIST_DESCRIPTOR.index -> - builder.parentList = - deserializer.deserializeList(PARENTLIST_DESCRIPTOR) { - val col0 = mutableListOf>() - while (hasNextElement()) { - val el0 = deserializer.deserializeList(PARENTLIST_C0_DESCRIPTOR) { - val col1 = mutableListOf() - while (hasNextElement()) { - val el1 = if (nextHasValue()) { PayloadStructDocumentDeserializer().deserialize(deserializer) } else { deserializeNull(); continue } - col1.add(el1) - } - col1 - } - col0.add(el0) - } - col0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerMapTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerMapTest.kt deleted file mode 100644 index d93a75b41..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerMapTest.kt +++ /dev/null @@ -1,795 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import io.kotest.matchers.maps.shouldContainExactly -import kotlin.test.Test - -class XmlDeserializerMapTest { - - @Test - fun itHandlesMapsWithDefaultNodeNames() { - val payload = """ - - - - key1 - 1 - - - key2 - 2 - - - - """.encodeToByteArray() - val fieldDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - - var actual = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - val expected = mapOf("key1" to 1, "key2" to 2) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesMapsWithCustomNodeNames() { - val payload = """ - - - - key1 - 1 - - - key2 - 2 - - - - """.encodeToByteArray() - val fieldDescriptor = - SdkFieldDescriptor(SerialKind.Map, XmlSerialName("mymap"), XmlMapName("myentry", "mykey", "myvalue")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - var actual = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - val expected = mapOf("key1" to 1, "key2" to 2) - actual.shouldContainExactly(expected) - } - - // https://awslabs.github.io/smithy/1.0/spec/core/xml-traits.html#flattened-map-serialization - @Test - fun itHandlesFlatMaps() { - val payload = """ - - - key1 - 1 - - - key2 - 2 - - - key3 - 3 - - - """.encodeToByteArray() - val containerFieldDescriptor = - SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), XmlMapName(null, "key", "value"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(containerFieldDescriptor) - } - var actual = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - containerFieldDescriptor.index -> - actual = - deserializer.deserializeMap(containerFieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - val expected = mapOf("key1" to 1, "key2" to 2, "key3" to 3) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesEmptyMaps() { - val payload = """ - - - - """.encodeToByteArray() - val containerFieldDescriptor = - SdkFieldDescriptor(SerialKind.Map, XmlSerialName("Map")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(containerFieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf() - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - containerFieldDescriptor.index -> - actual = - deserializer.deserializeMap(containerFieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = emptyMap() - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesSparseMaps() { - val payload = """ - - - - key1 - 1 - - - key2 - - - - - """.encodeToByteArray() - val fieldDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf() - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map = mutableMapOf() - while (hasNextEntry()) { - val key = key() - val value = when (nextHasValue()) { - true -> deserializeInt() - false -> deserializeNull() - } - - map[key] = value - } - return@deserializeMap map - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = mapOf("key1" to 1, "key2" to null) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesCheckingMapValuesForNull() { - val payload = """ - - - - key1 - 1 - - - key2 - - - - - """.encodeToByteArray() - val fieldDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf() - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = mapOf("key1" to 1) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesNestedMap() { - val payload = """ - - - - outer1 - - - - inner1 - innerValue1 - - - inner2 - innerValue2 - - - - - - outer2 - - - - inner3 - innerValue3 - - - inner4 - innerValue4 - - - - - - - """.encodeToByteArray() - val ELEMENT_MAP_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map"), XmlMapName(entry = "outerEntry", value = "outerValue")) - val nestedMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("nestedMap"), XmlMapName(entry = "innerEntry", value = "innerValue")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_MAP_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf>() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - ELEMENT_MAP_FIELD_DESCRIPTOR.index -> - actual = - deserializer.deserializeMap(ELEMENT_MAP_FIELD_DESCRIPTOR) { - val map0 = mutableMapOf>() - while (hasNextEntry()) { - val k0 = key() - val v0 = deserializer.deserializeMap(nestedMapDescriptor) { - val map1 = mutableMapOf() - while (hasNextEntry()) { - val k1 = key() - val v1 = if (nextHasValue()) { deserializeString() } else { deserializeNull(); continue } - map1[k1] = v1 - } - map1 - } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = mapOf( - "outer1" to mapOf("inner1" to "innerValue1", "inner2" to "innerValue2"), - "outer2" to mapOf("inner3" to "innerValue3", "inner4" to "innerValue4"), - ) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesNestedStructAsValue() { - val payload = """ - - - - foo - - there - - - - baz - - bye - - - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val resp = XmlMapsOperationDeserializer().deserialize(deserializer) - - println(resp) - } - - // https://github.com/awslabs/aws-sdk-kotlin/issues/962 - @Test - fun itHandlesConsecutiveFlatMaps() { - val payload = """ - - - key1 - 1 - - - key2 - 2 - - - key3 - 3 - - - key4 - 4 - - - key5 - 5 - - - """.encodeToByteArray() - val firstMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("firstMap"), XmlMapName(null, "key", "value"), Flattened) - val secondMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("secondMap"), XmlMapName(null, "key", "value"), Flattened) - - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(firstMapDescriptor) - field(secondMapDescriptor) - } - var firstMap = mutableMapOf() - var secondMap = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - firstMapDescriptor.index -> - firstMap = - deserializer.deserializeMap(firstMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - secondMapDescriptor.index -> - secondMap = - deserializer.deserializeMap(secondMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expectedFirstMap = mapOf("key1" to 1, "key2" to 2, "key3" to 3) - firstMap.shouldContainExactly(expectedFirstMap) - val expectedSecondMap = mapOf("key4" to 4, "key5" to 5) - secondMap.shouldContainExactly(expectedSecondMap) - } - - @Test - fun itHandlesMapsFollowedByFlatMaps() { - val payload = """ - - - - key1 - 1 - - - key2 - 2 - - - - key3 - 3 - - - key4 - 4 - - - """.encodeToByteArray() - val mapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map")) - val flatMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(mapDescriptor) - field(flatMapDescriptor) - } - - var map = mutableMapOf() - var flatMap = mutableMapOf() - - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - mapDescriptor.index -> - map = - deserializer.deserializeMap(mapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - flatMapDescriptor.index -> - flatMap = - deserializer.deserializeMap(flatMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - map.shouldContainExactly(mapOf("key1" to 1, "key2" to 2)) - flatMap.shouldContainExactly(mapOf("key3" to 3, "key4" to 4)) - } - - @Test - fun itHandlesFlatMapsFollowedByMaps() { - val payload = """ - - - key3 - 3 - - - key4 - 4 - - - - key1 - 1 - - - key2 - 2 - - - - """.encodeToByteArray() - val mapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map")) - val flatMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(mapDescriptor) - field(flatMapDescriptor) - } - - var map = mutableMapOf() - var flatMap = mutableMapOf() - - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - mapDescriptor.index -> - map = - deserializer.deserializeMap(mapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - flatMapDescriptor.index -> - flatMap = - deserializer.deserializeMap(flatMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - map.shouldContainExactly(mapOf("key1" to 1, "key2" to 2)) - flatMap.shouldContainExactly(mapOf("key3" to 3, "key4" to 4)) - } -} - -internal class XmlMapsOperationDeserializer() { - - companion object { - private val MYMAP_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("myMap")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("XmlMapsInputOutput")) - field(MYMAP_DESCRIPTOR) - } - } - - fun deserialize(deserializer: XmlDeserializer): XmlMapsInputOutput { - val builder = XmlMapsInputOutput.dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - MYMAP_DESCRIPTOR.index -> - builder.myMap = - deserializer.deserializeMap(MYMAP_DESCRIPTOR) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { GreetingStructDocumentDeserializer().deserialize(deserializer) } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} - -internal class GreetingStructDocumentDeserializer { - - companion object { - private val HI_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("hi")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("GreetingStruct")) - field(HI_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): GreetingStruct { - val builder = GreetingStruct.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - HI_DESCRIPTOR.index -> builder.hi = deserializeString() - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} - -class XmlMapsInputOutput private constructor(builder: BuilderImpl) { - val myMap: Map? = builder.myMap - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): XmlMapsInputOutput = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("XmlMapsInputOutput(") - append("myMap=$myMap)") - } - - override fun hashCode(): kotlin.Int { - var result = myMap?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as XmlMapsInputOutput - - if (myMap != other.myMap) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): XmlMapsInputOutput = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): XmlMapsInputOutput - fun myMap(myMap: Map): Builder - } - - interface DslBuilder { - var myMap: Map? - - fun build(): XmlMapsInputOutput - } - - private class BuilderImpl() : Builder, DslBuilder { - override var myMap: Map? = null - - constructor(x: XmlMapsInputOutput) : this() { - this.myMap = x.myMap - } - - override fun build(): XmlMapsInputOutput = XmlMapsInputOutput(this) - override fun myMap(myMap: Map): Builder = apply { this.myMap = myMap } - } -} - -class GreetingStruct private constructor(builder: BuilderImpl) { - val hi: String? = builder.hi - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): GreetingStruct = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("GreetingStruct(") - append("hi=$hi)") - } - - override fun hashCode(): kotlin.Int { - var result = hi?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as GreetingStruct - - if (hi != other.hi) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): GreetingStruct = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): GreetingStruct - fun hi(hi: String): Builder - } - - interface DslBuilder { - var hi: String? - - fun build(): GreetingStruct - } - - private class BuilderImpl() : Builder, DslBuilder { - override var hi: String? = null - - constructor(x: GreetingStruct) : this() { - this.hi = x.hi - } - - override fun build(): GreetingStruct = GreetingStruct(this) - override fun hi(hi: String): Builder = apply { this.hi = hi } - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerNamespaceTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerNamespaceTest.kt deleted file mode 100644 index becb7e347..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerNamespaceTest.kt +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertEquals - -// See https://awslabs.github.io/smithy/spec/xml.html#xmlname-trait -class XmlDeserializerNamespaceTest { - - @Test - fun `it handles struct with namespace declarations but default tags`() { - val payload = """ - - example1 - example2 - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = NamespaceStructTest.deserialize(deserializer) - - assertEquals("example1", bst.foo) - assertEquals("example2", bst.bar) - } - - class NamespaceStructTest { - var foo: String? = null - var bar: String? = null - - companion object { - val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("foo")) - val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("bar")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com")) - field(FOO_DESCRIPTOR) - field(BAR_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): NamespaceStructTest { - val result = NamespaceStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - FOO_DESCRIPTOR.index -> result.foo = deserializeString() - BAR_DESCRIPTOR.index -> result.bar = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } - } - - @Test - fun `it handles struct with node namespace`() { - val payload = """ - - example1 - example2 - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = NodeNamespaceStructTest.deserialize(deserializer) - - assertEquals("example1", bst.foo) - assertEquals("example2", bst.bar) - } - - class NodeNamespaceStructTest { - var foo: String? = null - var bar: String? = null - - companion object { - val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("foo")) - val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("baz:bar")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com", "baz")) - field(FOO_DESCRIPTOR) - field(BAR_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): NodeNamespaceStructTest { - val result = NodeNamespaceStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - FOO_DESCRIPTOR.index -> result.foo = deserializeString() - BAR_DESCRIPTOR.index -> result.bar = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerPrimitiveTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerPrimitiveTest.kt deleted file mode 100644 index 26e02418f..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerPrimitiveTest.kt +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.math.abs -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertTrue - -class XmlDeserializerPrimitiveTest { - @Test - fun itHandlesDoubles() { - val deserializer = XmlPrimitiveDeserializer("1.2".wrapInStruct(), SdkFieldDescriptor(SerialKind.Double, XmlSerialName("node"))) - val actual = deserializer.deserializeDouble() - val expected = 1.2 - assertTrue(abs(actual - expected) <= 0.0001) - } - - @Test - fun itHandlesFloats() { - val deserializer = XmlPrimitiveDeserializer("1.2".wrapInStruct(), SdkFieldDescriptor(SerialKind.Float, XmlSerialName("node"))) - val actual = deserializer.deserializeFloat() - val expected = 1.2f - assertTrue(abs(actual - expected) <= 0.0001f) - } - - @Test - fun itHandlesInt() { - val deserializer = XmlPrimitiveDeserializer("${Int.MAX_VALUE}".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - val actual = deserializer.deserializeInt() - val expected = 2147483647 - assertEquals(expected, actual) - } - - @Test - fun itHandlesByteAsNumber() { - val deserializer = XmlPrimitiveDeserializer("1".wrapInStruct(), SdkFieldDescriptor(SerialKind.Byte, XmlSerialName("node"))) - val actual = deserializer.deserializeByte() - val expected: Byte = 1 - assertEquals(expected, actual) - } - - @Test - fun itHandlesShort() { - val deserializer = XmlPrimitiveDeserializer("${Short.MAX_VALUE}".wrapInStruct(), SdkFieldDescriptor(SerialKind.Short, XmlSerialName("node"))) - val actual = deserializer.deserializeShort() - val expected: Short = 32767 - assertEquals(expected, actual) - } - - @Test - fun itHandlesLong() { - val deserializer = XmlPrimitiveDeserializer("${Long.MAX_VALUE}".wrapInStruct(), SdkFieldDescriptor(SerialKind.Long, XmlSerialName("node"))) - val actual = deserializer.deserializeLong() - val expected = 9223372036854775807L - assertEquals(expected, actual) - } - - @Test - fun itHandlesBool() { - val deserializer = XmlPrimitiveDeserializer("true".wrapInStruct(), SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("node"))) - val actual = deserializer.deserializeBoolean() - assertTrue(actual) - } - - @Test - fun itFailsInvalidTypeSpecificationForInt() { - val deserializer = XmlPrimitiveDeserializer("1.2".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - assertFailsWith(DeserializationException::class) { - deserializer.deserializeInt() - } - } - - // TODO: It's unclear if this test should result in an exception or null value. - @Test - fun itFailsMissingTypeSpecificationForInt() { - val deserializer = XmlPrimitiveDeserializer("".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - assertFailsWith(DeserializationException::class) { - deserializer.deserializeInt() - } - } - - // TODO: It's unclear if this test should result in an exception or null value. - @Test - fun itFailsWhitespaceTypeSpecificationForInt() { - val deserializer = XmlPrimitiveDeserializer(" ".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - assertFailsWith(DeserializationException::class) { - deserializer.deserializeInt() - } - } - - private fun String.wrapInStruct(): ByteArray = "$this".encodeToByteArray() -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerStructTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerStructTest.kt deleted file mode 100644 index a500f0d66..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerStructTest.kt +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertEquals - -class XmlDeserializerStructTest { - @Test - fun `it handles basic structs with attribs`() { - val payload = """ - - - - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = StructWithAttribsClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - } - - @Test - fun `it handles basic structs with multi attribs and text`() { - val payload = """ - - - - - nodeval - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = StructWithMultiAttribsAndTextValClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - assertEquals("nodeval", bst.txt) - } - - @Test - fun itHandlesBasicStructsWithAttribsAndText() { - val payload = """ - - x1 - - true - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = BasicAttribTextStructTest.deserialize(deserializer) - - assertEquals(1, bst.xa) - assertEquals("x1", bst.xt) - assertEquals(2, bst.y) - assertEquals(1, bst.unknownFieldCount) - } - - class BasicAttribTextStructTest { - var xa: Int? = null - var xt: String? = null - var y: Int? = null - var z: Boolean? = null - var unknownFieldCount: Int = 0 - - companion object { - val X_ATTRIB_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("xa"), XmlAttribute) - val X_VALUE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("ya"), XmlAttribute) - val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("z")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_ATTRIB_DESCRIPTOR) - field(X_VALUE_DESCRIPTOR) - field(Y_DESCRIPTOR) - field(Z_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): BasicAttribTextStructTest { - val result = BasicAttribTextStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_ATTRIB_DESCRIPTOR.index -> result.xa = deserializeInt() - X_VALUE_DESCRIPTOR.index -> result.xt = deserializeString() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - Z_DESCRIPTOR.index -> result.z = deserializeBoolean() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - result.unknownFieldCount++ - skipValue() - } - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } - } - - @Test - fun itHandlesBasicStructs() { - val payload = """ - - 1 - 2 - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = SimpleStructClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - } - - @Test - fun itHandlesBasicStructsWithNullValues() { - val payload1 = """ - - a - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload1) - val bst = SimpleStructOfStringsClass.deserialize(deserializer) - - assertEquals("a", bst.x) - assertEquals("", bst.y) - - val payload2 = """ - - - 2 - - """.encodeToByteArray() - - val deserializer2 = XmlDeserializer(payload2) - val bst2 = SimpleStructOfStringsClass.deserialize(deserializer2) - - assertEquals("", bst2.x) - assertEquals("2", bst2.y) - } - - @Test - fun itEnumeratesUnknownStructFields() { - val payload = """ - - 1 - 2 - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = SimpleStructClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - assertEquals("strval", bst.z) - } - - @Test - fun itHandlesNestedXmlStructures() { - val payload = """ - - - Foo1 - - Bar1 - - Foo2 - - Bar2 - - - - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = RecursiveShapesOperationDeserializer().deserialize(deserializer) - - println(bst.nested?.nested) - } - - class BasicUnwrappedStructTest { - var x: String? = null - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("x")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - trait(XmlUnwrappedOutput) - field(X_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): BasicUnwrappedStructTest { - val result = BasicUnwrappedStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicUnwrappedStructTest deserializer")) - } - } - } - return result - } - } - } - - @Test - fun itHandlesBasicUnwrappedStructs() { - val payload = """ - text - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = BasicUnwrappedStructTest.deserialize(deserializer) - - assertEquals("text", bst.x) - } - - @Test - fun itHandlesBasicUnwrappedStructsWithNullValues() { - val payload = """ - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = BasicUnwrappedStructTest.deserialize(deserializer) - - assertEquals(null, bst.x) - } - - class AliasStruct { - var message: String? = null - var attribute: String? = null - - companion object { - val MESSAGE_DESCRIPTOR = SdkFieldDescriptor( - SerialKind.String, - XmlSerialName("Message"), - XmlAliasName("message"), - XmlAliasName("msg"), - ) - val ATTRIBUTE_DESCRIPTOR = SdkFieldDescriptor( - SerialKind.String, - XmlAttribute, - XmlSerialName("Attribute"), - XmlAliasName("attribute"), - XmlAliasName("attr"), - ) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Struct")) - trait(XmlAliasName("struct")) - field(MESSAGE_DESCRIPTOR) - field(ATTRIBUTE_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): AliasStruct { - val result = AliasStruct() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - MESSAGE_DESCRIPTOR.index -> result.message = deserializeString() - ATTRIBUTE_DESCRIPTOR.index -> result.attribute = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in AliasStruct deserializer")) - } - } - } - return result - } - } - } - - @Test - fun itHandlesAliasMatchingOnElements() { - val tests = listOf( - "Hi there", - "Hi there", - "Hi there", - "Hi there", - ) - tests.forEach { payload -> - val deserializer = XmlDeserializer(payload.encodeToByteArray()) - val bst = AliasStruct.deserialize(deserializer) - assertEquals("Hi there", bst.message, "Can't find 'Hi there' in $payload") - } - } - - @Test - fun itHandlesAliasMatchingOnAttributes() { - val tests = listOf( - """""", - """""", - """""", - ) - tests.forEach { payload -> - val deserializer = XmlDeserializer(payload.encodeToByteArray()) - val bst = AliasStruct.deserialize(deserializer) - assertEquals("Hi there", bst.attribute, "Can't find 'Hi there' in $payload") - } - } -} - -internal class RecursiveShapesOperationDeserializer { - - companion object { - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutput")) - field(NESTED_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): RecursiveShapesInputOutput { - val builder = RecursiveShapesInputOutput.Builder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - NESTED_DESCRIPTOR.index -> builder.nested = RecursiveShapesInputOutputNested1DocumentDeserializer().deserialize(deserializer) - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} - -internal class RecursiveShapesInputOutputNested1DocumentDeserializer { - - companion object { - private val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("foo")) - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - field(FOO_DESCRIPTOR) - field(NESTED_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): RecursiveShapesInputOutputNested1 { - val builder = RecursiveShapesInputOutputNested1.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - FOO_DESCRIPTOR.index -> builder.foo = deserializeString() - NESTED_DESCRIPTOR.index -> builder.nested = RecursiveShapesInputOutputNested2DocumentDeserializer().deserialize(deserializer) - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} - -internal class RecursiveShapesInputOutputNested2DocumentDeserializer { - - companion object { - private val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("bar")) - private val RECURSIVEMEMBER_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("recursiveMember")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - field(BAR_DESCRIPTOR) - field(RECURSIVEMEMBER_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): RecursiveShapesInputOutputNested2 { - val builder = RecursiveShapesInputOutputNested2.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - BAR_DESCRIPTOR.index -> builder.bar = deserializeString() - RECURSIVEMEMBER_DESCRIPTOR.index -> builder.recursiveMember = RecursiveShapesInputOutputNested1DocumentDeserializer().deserialize(deserializer) - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlSerializerTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlSerializerTest.kt deleted file mode 100644 index ad60d281c..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlSerializerTest.kt +++ /dev/null @@ -1,963 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertEquals - -/* -Remove all whitespace and newline chars from XML string and return the compact form -e.g. - -``` - - - 1 - - -``` - -becomes: `1` - */ -private fun String.toXmlCompactString(): String = - trimIndent() - .replace("\n", "") - .replace(Regex(">\\s+"), ">") - -class XmlSerializerTest { - - @Test - fun canSerializeClassWithClassField() { - val a = A( - B(2), - ) - val xml = XmlSerializer() - a.serialize(xml) - assertEquals("""2""", xml.toByteArray().decodeToString()) - } - - class A(private val b: B) : SdkSerializable { - companion object { - val descriptorB: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("b")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("a")) - field(descriptorB) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(descriptorB, b) - } - } - } - - data class B(private val value: Int) : SdkSerializable { - companion object { - val descriptorValue = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("v")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("b")) - field(descriptorValue) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(descriptorValue, value) - } - } - } - - @Test - fun canSerializePrimitiveList() { - // https://awslabs.github.io/smithy/spec/xml.html#wrapped-list-serialization - val list = listOf("example1", "example2", "example3") - val xml = XmlSerializer() - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(listDescriptor) - } - - xml.serializeStruct(objDescriptor) { - listField(listDescriptor) { - for (value in list) { - serializeString(value) - } - } - } - - val expected = """ - - - example1 - example2 - example3 - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeRenamedList() { - val list = listOf("example1", "example2", "example3") - val xml = XmlSerializer() - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("values"), XmlCollectionName("Item")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(listDescriptor) - } - - xml.serializeStruct(objDescriptor) { - listField(listDescriptor) { - for (value in list) { - serializeString(value) - } - } - } - - val expected = """ - - - example1 - example2 - example3 - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeFlattenedList() { - // https://awslabs.github.io/smithy/spec/xml.html#flattened-list-serialization - val list = listOf("example1", "example2", "example3") - val xml = XmlSerializer() - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("flat"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(listDescriptor) - } - - xml.serializeStruct(objDescriptor) { - listField(listDescriptor) { - for (value in list) { - serializeString(value) - } - } - } - - val expected = """ - - example1 - example2 - example3 - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeListOfClasses() { - val obj = listOf( - B(1), - B(2), - B(3), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"))) { - for (value in obj) { - serializeSdkSerializable(value) - } - } - - val expected = """ - - - 1 - - - 2 - - - 3 - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeFlatListOfClasses() { - val obj = listOf( - B(1), - B(2), - B(3), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), Flattened)) { - for (value in obj) { - serializeSdkSerializable(value) - } - } - val expected = """ - - 1 - - - 2 - - - 3 - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMap() { - // See https://awslabs.github.io/smithy/spec/xml.html#wrapped-map-serialization - val foo = Foo( - mapOf( - "example-key1" to "example1", - "example-key2" to "example2", - ), - ) - val xml = XmlSerializer() - foo.serialize(xml) - - val expected = """ - - - - example-key1 - example1 - - - example-key2 - example2 - - - - """.toXmlCompactString() - - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeFlattenedMap() { - // See https://awslabs.github.io/smithy/spec/xml.html#flattened-map-serialization - val bar = Bar( - mapOf( - "example-key1" to "example1", - "example-key2" to "example2", - "example-key3" to "example3", - ), - ) - val serializer = XmlSerializer() - bar.serialize(serializer) - - val expected = """ - - - example-key1 - example1 - - - - example-key2 - example2 - - - - example-key3 - example3 - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMapOfLists() { - val objs = mapOf( - "A1" to listOf("a", "b", "c"), - "A2" to listOf("d", "e", "f"), - "A3" to listOf("g", "h", "i"), - ) - val xml = XmlSerializer() - xml.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("objs"))) { - for (obj in objs) { - listEntry(obj.key, SdkFieldDescriptor(SerialKind.List, XmlSerialName("elements"))) { - for (v in obj.value) { - serializeString(v) - } - } - } - } - - val expected = """ - - - A1 - - - a - b - c - - - - - A2 - - - d - e - f - - - - - A3 - - - g - h - i - - - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeListOfLists() { - val objs = listOf( - listOf("a", "b", "c"), - listOf("d", "e", "f"), - listOf("g", "h", "i"), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("objs"))) { - for (obj in objs) { - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("elements"))) { - for (v in obj) { - serializeString(v) - } - } - } - } - - val expected = """ - - - a - b - c - - - d - e - f - - - g - h - i - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeListOfMaps() { - val objs = listOf( - mapOf("a" to "b", "c" to "d"), - mapOf("e" to "f", "g" to "h"), - mapOf("i" to "j", "k" to "l"), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("elements"))) { - for (obj in objs) { - xml.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("entries"))) { - for (v in obj) { - entry(v.key, v.value) - } - } - } - } - val expected = """ - - - - a - b - - - c - d - - - - - e - f - - - g - h - - - - - i - j - - - k - l - - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMapOfMaps() { - val objs = mapOf( - "A1" to mapOf("a" to "b", "c" to "d"), - "A2" to mapOf("e" to "f", "g" to "h"), - "A3" to mapOf("i" to "j", "k" to "l"), - ) - val serializer = XmlSerializer() - serializer.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("objs"))) { - for (obj in objs) { - mapEntry(obj.key, SdkFieldDescriptor(SerialKind.Map)) { - for (v in obj.value) { - entry(v.key, v.value) - } - } - } - } - - // NOTE the child map entries do not have a surrounding tag around them, much like a map of structs omit the - // structure tag - val expected = """ - - - A1 - - - a - b - - - c - d - - - - - A2 - - - e - f - - - g - h - - - - - A3 - - - i - j - - - k - l - - - - - """.toXmlCompactString() - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMapOfStructs() { - val objs = mapOf( - "foo" to B(1), - "bar" to B(2), - ) - - val serializer = XmlSerializer() - - serializer.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("myMap"))) { - objs.entries.forEach { (key, value) -> entry(key, value) } - } - - val expected = """ - - - foo - - 1 - - - - bar - - 2 - - - - """.toXmlCompactString() - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - class Bar(var flatMap: Map? = null) : SdkSerializable { - companion object { - val FLAT_MAP_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), XmlMapName(entry = "flatMap"), Flattened) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Bar")) - field(FLAT_MAP_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - mapField(FLAT_MAP_DESCRIPTOR) { - for (value in flatMap!!) { - entry(value.key, value.value) - } - } - } - } - } - - class Foo(var values: Map? = null) : SdkSerializable { - companion object { - val FLAT_MAP_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values"), XmlMapName(entry = "entry")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(FLAT_MAP_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - mapField(FLAT_MAP_DESCRIPTOR) { - for (value in values!!) { - entry(value.key, value.value) - } - } - } - } - } - - @Test - fun canSerializeAllPrimitives() { - val xml = XmlSerializer() - val data = Primitives( - true, 10, 20, 30, 40, 50f, 60.0, 'A', "Str0", - listOf(1, 2, 3), - ) - data.serialize(xml) - - assertEquals("""true1020304050.060.0AStr0123""", xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeNamespaces() { - // See https://awslabs.github.io/smithy/spec/xml.html#xmlnamespace-trait - val myStructure = MyStructure1("example", "example") - val xml = XmlSerializer() - myStructure.serialize(xml) - val expected1 = """ - - example - example - - """.toXmlCompactString() - assertEquals(expected1, xml.toByteArray().decodeToString()) - - val myStructure2 = MyStructure2("example", "example") - val xml2 = XmlSerializer() - myStructure2.serialize(xml2) - val expected2 = """ - - example - example - - """.toXmlCompactString() - assertEquals(expected2, xml2.toByteArray().decodeToString()) - } - - @Test - fun canSerializeNestedNamespaces() { - val input = XmlNamespacesRequest( - nested = XmlNamespaceNested( - foo = "Foo", - values = listOf("Bar", "Baz"), - ), - ) - - val serializer = XmlSerializer() - input.serialize(serializer) - - val expected = """ - - - Foo - - Bar - Baz - - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - class MyStructure1(private val foo: String, private val bar: String) : SdkSerializable { - companion object { - val fooDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("foo")) - val barDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("bar")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com")) - field(fooDescriptor) - field(barDescriptor) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(fooDescriptor, foo) - field(barDescriptor, bar) - } - } - } - - class MyStructure2(private val foo: String, private val bar: String) : SdkSerializable { - companion object { - val fooDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("foo")) - val barDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("baz:bar")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com", "baz")) - - field(fooDescriptor) - field(barDescriptor) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(fooDescriptor, foo) - field(barDescriptor, bar) - } - } - } - - @Test - fun canIgnoresNestedStructNamespaces() { - /* - @xmlNamespace(uri: "http://foo.com") - structure Foo { - nested: Bar, - } - - // Ignored - not at top level - // TODO - nothing in the spec defines this...only the protocol tests - @xmlNamespace(uri: "http://bar.com") - structure Bar { - x: String - } - */ - - val serializer = XmlSerializer() - val nestedDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - trait(XmlNamespace("http://foo.com")) - field(nestedDescriptor) - } - - val nested = object : SdkSerializable { - override fun serialize(serializer: Serializer) { - val xDescriptor = SdkFieldDescriptor(SerialKind.String, XmlSerialName("x")) - val obj2Descriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Bar")) - trait(XmlNamespace("http://bar.com")) - field(xDescriptor) - } - serializer.serializeStruct(obj2Descriptor) { - field(xDescriptor, "blerg") - } - } - } - - serializer.serializeStruct(objDescriptor) { - field(nestedDescriptor, nested) - } - - val expected = """ - - - blerg - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun itSerializesRecursiveShapes() { - val expected = """ - - - Foo1 - - Bar1 - - Foo2 - - Bar2 - - - - - - """.toXmlCompactString() - - val input = RecursiveShapesInputOutput { - nested = RecursiveShapesInputOutputNested1 { - foo = "Foo1" - nested = RecursiveShapesInputOutputNested2 { - bar = "Bar1" - recursiveMember = RecursiveShapesInputOutputNested1 { - foo = "Foo2" - nested = RecursiveShapesInputOutputNested2 { - bar = "Bar2" - } - } - } - } - } - - val serializer = XmlSerializer() - RecursiveShapesInputOutputSerializer().serialize(serializer, input) - val actual = serializer.toByteArray().decodeToString() - println(actual) - assertEquals(expected, actual) - } - - @Test - fun itCanSerializeAttributes() { - val boolDescriptor = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("bool"), XmlAttribute) - val strDescriptor = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("str"), XmlAttribute) - val intDescriptor = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("number"), XmlAttribute) - // timestamps are ignored as they aren't special cased (as of right now) but rather serialized through string/raw - - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(boolDescriptor) - field(strDescriptor) - field(intDescriptor) - } - - // NOTE: attribute fields MUST be generated as the first fields after serializeStruct() to work properly - val serializer = XmlSerializer() - serializer.serializeStruct(objDescriptor) { - field(boolDescriptor, true) - field(strDescriptor, "bar") - field(intDescriptor, 2) - } - - val expected = """ - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun itCanSerializeAttributesWithNamespaces() { - val nestedDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nestedField"), XmlNamespace("https://example.com", "xsi")) - - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(nestedDescriptor) - } - - val nestedSerializer = object : SdkSerializable { - override fun serialize(serializer: Serializer) { - val attrDescriptor = SdkFieldDescriptor(SerialKind.String, XmlSerialName("xsi:myAttr"), XmlAttribute) - val nestedObjDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Nested")) - field(attrDescriptor) - } - serializer.serializeStruct(nestedObjDescriptor) { - field(attrDescriptor, "nestedAttrValue") - } - } - } - - // NOTE: attribute fields MUST be generated as the first fields after serializeStruct() to work properly - val serializer = XmlSerializer() - serializer.serializeStruct(objDescriptor) { - field(nestedDescriptor, nestedSerializer) - } - - // The order these attributes come out in exactly the order they're put in (as defined by XmlSerializer). - val expected = """ - - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } -} - -data class Primitives( - // val unit: Unit, - val boolean: Boolean, - val byte: Byte, - val short: Short, - val int: Int, - val long: Long, - val float: Float, - val double: Double, - val char: Char, - val string: String, - // val unitNullable: Unit?, - val listInt: List, -) : SdkSerializable { - companion object { - val descriptorBoolean = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("boolean")) - val descriptorByte = SdkFieldDescriptor(SerialKind.Byte, XmlSerialName("byte")) - val descriptorShort = SdkFieldDescriptor(SerialKind.Short, XmlSerialName("short")) - val descriptorInt = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("int")) - val descriptorLong = SdkFieldDescriptor(SerialKind.Long, XmlSerialName("long")) - val descriptorFloat = SdkFieldDescriptor(SerialKind.Float, XmlSerialName("float")) - val descriptorDouble = SdkFieldDescriptor(SerialKind.Double, XmlSerialName("double")) - val descriptorChar = SdkFieldDescriptor(SerialKind.Char, XmlSerialName("char")) - val descriptorString = SdkFieldDescriptor(SerialKind.String, XmlSerialName("string")) - - // val descriptorUnitNullable = SdkFieldDescriptor("unitNullable") - val descriptorListInt = SdkFieldDescriptor(SerialKind.List, XmlSerialName("listInt"), XmlCollectionName(element = "number")) - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("struct"))) { - serializeNull() - field(descriptorBoolean, boolean) - field(descriptorByte, byte) - field(descriptorShort, short) - field(descriptorInt, int) - field(descriptorLong, long) - field(descriptorFloat, float) - field(descriptorDouble, double) - field(descriptorChar, char) - field(descriptorString, string) - // serializeNull(descriptorUnitNullable) - listField(descriptorListInt) { - for (value in listInt) { - serializeInt(value) - } - } - } - } -} - -// structure RecursiveShapesInputOutput { -// nested: RecursiveShapesInputOutputNested1 -// } -// -// structure RecursiveShapesInputOutputNested1 { -// foo: String, -// nested: RecursiveShapesInputOutputNested2 -// } -// -// structure RecursiveShapesInputOutputNested2 { -// bar: String, -// recursiveMember: RecursiveShapesInputOutputNested1, -// } -internal class RecursiveShapesInputOutputSerializer { - companion object { - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutput")) - field(NESTED_DESCRIPTOR) - } - } - - fun serialize(serializer: Serializer, input: RecursiveShapesInputOutput) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.nested?.let { field(NESTED_DESCRIPTOR, RecursiveShapesInputOutputNested1DocumentSerializer(it)) } - } - } -} - -internal class RecursiveShapesInputOutputNested1DocumentSerializer(val input: RecursiveShapesInputOutputNested1) : SdkSerializable { - - companion object { - private val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("foo")) - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutputNested1")) - field(FOO_DESCRIPTOR) - field(NESTED_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.foo?.let { field(FOO_DESCRIPTOR, it) } - input.nested?.let { field(NESTED_DESCRIPTOR, RecursiveShapesInputOutputNested2DocumentSerializer(it)) } - } - } -} - -internal class RecursiveShapesInputOutputNested2DocumentSerializer(val input: RecursiveShapesInputOutputNested2) : SdkSerializable { - - companion object { - private val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("bar")) - private val RECURSIVEMEMBER_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("recursiveMember")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutputNested2")) - field(BAR_DESCRIPTOR) - field(RECURSIVEMEMBER_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.bar?.let { field(BAR_DESCRIPTOR, it) } - input.recursiveMember?.let { field(RECURSIVEMEMBER_DESCRIPTOR, RecursiveShapesInputOutputNested1DocumentSerializer(it)) } - } - } -} diff --git a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt index 90d4a6c5e..0f01ee165 100644 --- a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt +++ b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt @@ -7,7 +7,8 @@ package aws.smithy.kotlin.benchmarks.serde.xml import aws.smithy.kotlin.benchmarks.serde.BenchmarkBase import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.model.CountriesAndStates import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.deserializeCountriesAndStatesDocument -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer +import aws.smithy.kotlin.runtime.serde.xml.root +import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader import kotlinx.benchmark.* import kotlinx.coroutines.runBlocking @@ -18,7 +19,7 @@ open class XmlDeserializerBenchmark : BenchmarkBase() { private fun deserialize(): CountriesAndStates = runBlocking { - val deserializer = XmlDeserializer(source) + val deserializer = xmlStreamReader(source).root() deserializeCountriesAndStatesDocument(deserializer) } diff --git a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt index e9a819446..dbd6c94bf 100644 --- a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt +++ b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt @@ -8,8 +8,9 @@ import aws.smithy.kotlin.benchmarks.serde.BenchmarkBase import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.model.CountriesAndStates import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.deserializeCountriesAndStatesDocument import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.serializeCountriesAndStatesDocument -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer import aws.smithy.kotlin.runtime.serde.xml.XmlSerializer +import aws.smithy.kotlin.runtime.serde.xml.root +import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader import kotlinx.benchmark.* import kotlinx.coroutines.runBlocking @@ -21,7 +22,7 @@ open class XmlSerializerBenchmark : BenchmarkBase() { @Setup fun init() { dataSet = runBlocking { - val deserializer = XmlDeserializer(source) + val deserializer = xmlStreamReader(source).root() deserializeCountriesAndStatesDocument(deserializer) } } From c9214ea37c931e78cc5fb9b8ae6b5e7b76875482 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Sat, 24 Feb 2024 21:24:33 -0500 Subject: [PATCH 14/25] remove unused fun --- .../codegen/rendering/serde/XmlParserGenerator.kt | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index dac28c06c..b223885d7 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -15,7 +15,6 @@ import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex import software.amazon.smithy.kotlin.codegen.model.traits.UnwrappedXmlOutput import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator -import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext import software.amazon.smithy.model.shapes.* import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.model.traits.TimestampFormatTrait @@ -28,8 +27,6 @@ import software.amazon.smithy.utils.StringUtils * XML parser generator based on common deserializer interface and XML serde descriptors */ open class XmlParserGenerator( - // FIXME - shouldn't be necessary but XML serde descriptor generator needs it for rendering context - private val protocolGenerator: ProtocolGenerator, private val defaultTimestampFormat: TimestampFormatTrait.Format, ) : StructuredDataParserGenerator { @@ -41,14 +38,6 @@ open class XmlParserGenerator( val tagReader: String, ) - // FIXME - remove - open fun descriptorGenerator( - ctx: ProtocolGenerator.GenerationContext, - shape: Shape, - members: List, - writer: KotlinWriter, - ): XmlSerdeDescriptorGenerator = XmlSerdeDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members) - override fun operationDeserializer( ctx: ProtocolGenerator.GenerationContext, op: OperationShape, From d4bd423c49ac9458d01d6889310a45baee772434 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Sat, 24 Feb 2024 22:31:35 -0500 Subject: [PATCH 15/25] cleanup and renames --- .../kotlin/codegen/core/RuntimeTypes.kt | 2 +- .../rendering/serde/XmlParserGenerator.kt | 12 ++-- .../xml/Ec2QueryErrorDeserializer.kt | 16 ++--- .../xml/RestXmlErrorDeserializer.kt | 19 +++-- .../xml/{TagReader.kt => XmlTagReader.kt} | 71 +++++++++---------- .../{TagReaderTest.kt => XmlTagReaderTest.kt} | 26 +++---- .../serde/xml/XmlDeserializerBenchmark.kt | 5 +- .../serde/xml/XmlSerializerBenchmark.kt | 5 +- .../xml/SerdeXmlProtocolGenerator.kt | 2 +- .../kotlin/tests/serde/AbstractXmlTest.kt | 4 +- .../smithy/kotlin/tests/serde/XmlListTest.kt | 5 +- .../smithy/kotlin/tests/serde/XmlMapTest.kt | 5 +- 12 files changed, 81 insertions(+), 91 deletions(-) rename runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/{TagReader.kt => XmlTagReader.kt} (65%) rename runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/{TagReaderTest.kt => XmlTagReaderTest.kt} (84%) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index b21747fee..02f6ce3d9 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -278,7 +278,7 @@ object RuntimeTypes { val TagReader = symbol("TagReader") val xmlStreamReader = symbol("xmlStreamReader") - val root = symbol("root") + val xmlTagReader = symbol("xmlTagReader") val data = symbol("data") val tryData = symbol("tryData") } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index b223885d7..ce11e9469 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -76,7 +76,7 @@ open class XmlParserGenerator( documentMembers: List, writer: KotlinWriter, ) { - writer.write("val root = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) + writer.write("val root = #T(payload)", SerdeXml.xmlTagReader) val shape = ctx.model.expectShape(op.output.get()) val serdeCtx = unwrapOperationBody(ctx, SerdeCtx("root"), op, writer) @@ -148,7 +148,7 @@ open class XmlParserGenerator( val fnName = symbol.errorDeserializerName() writer.openBlock("internal fun #L(builder: #T.Builder, payload: ByteArray) {", fnName, symbol) .call { - writer.write("val root = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) + writer.write("val root = #T(payload)", SerdeXml.xmlTagReader) val serdeCtx = unwrapOperationError(ctx, SerdeCtx("root"), errorShape, writer) renderDeserializerBody(ctx, serdeCtx, errorShape, members, writer) } @@ -184,7 +184,7 @@ open class XmlParserGenerator( // short circuit when the shape has no modeled members to deserialize write("return #T.Builder().build()", symbol) } else { - writer.write("val root = #T(payload).#T()", SerdeXml.xmlStreamReader, SerdeXml.root) + writer.write("val root = #T(payload)", SerdeXml.xmlTagReader) write("return #T(root)", deserializeFn) } } @@ -198,7 +198,7 @@ open class XmlParserGenerator( ) { withBlock("loop@while(true) {", "}") { write("val curr = ${serdeCtx.tagReader}.nextTag() ?: break@loop") - withBlock("when(curr.startTag.name.tag) {", "}") { + withBlock("when(curr.tag.name.tag) {", "}") { block(this, serdeCtx.copy(tagReader = "curr")) if (ignoreUnexpected) { write("else -> {}") @@ -263,7 +263,7 @@ open class XmlParserGenerator( ) { val memberName = member.getTrait()?.value ?: member.memberName writer.withBlock( - "${serdeCtx.tagReader}.startTag.getAttr(#S)?.let {", + "${serdeCtx.tagReader}.tag.getAttr(#S)?.let {", "}", memberName, ) { @@ -613,7 +613,7 @@ open class XmlParserGenerator( } val member = members.first() - writer.withBlock("when(${serdeCtx.tagReader}.startTag.name.tag) {", "}") { + writer.withBlock("when(${serdeCtx.tagReader}.tag.name.tag) {", "}") { val name = member.getTrait()?.value ?: member.memberName write("// ${member.memberName} ${escape(member.id.toString())}") writeInline("#S -> builder.#L = ", name, member.defaultName()) diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt index 091d8a90a..bab0795ce 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt @@ -15,7 +15,7 @@ internal data class Ec2QueryError(val code: String?, val message: String?) @InternalApi public fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { - val response = Ec2QueryErrorResponseDeserializer.deserialize(xmlStreamReader(payload).root()) + val response = Ec2QueryErrorResponseDeserializer.deserialize(xmlTagReader(payload)) val firstError = response.errors.firstOrNull() return ErrorDetails(firstError?.code, firstError?.message, response.requestId) } @@ -25,14 +25,14 @@ public fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { * https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization */ internal object Ec2QueryErrorResponseDeserializer { - fun deserialize(root: TagReader): Ec2QueryErrorResponse = runCatching { + fun deserialize(root: XmlTagReader): Ec2QueryErrorResponse = runCatching { var errors: List? = null var requestId: String? = null - if (root.startTag.name.tag != "Response") error("expected found ${root.startTag}") + if (root.tag.name.tag != "Response") error("expected found ${root.tag}") loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.startTag.name.tag) { + when (curr.tag.name.tag) { "Errors" -> errors = Ec2QueryErrorListDeserializer.deserialize(curr) "RequestId" -> requestId = curr.data() } @@ -44,11 +44,11 @@ internal object Ec2QueryErrorResponseDeserializer { } internal object Ec2QueryErrorListDeserializer { - fun deserialize(root: TagReader): List { + fun deserialize(root: XmlTagReader): List { val errors = mutableListOf() loop@ while (true) { val curr = root.nextTag() ?: break@loop - when (curr.startTag.name.tag) { + when (curr.tag.name.tag) { "Error" -> { val el = Ec2QueryErrorDeserializer.deserialize(curr) errors.add(el) @@ -62,13 +62,13 @@ internal object Ec2QueryErrorListDeserializer { internal object Ec2QueryErrorDeserializer { - fun deserialize(root: TagReader): Ec2QueryError { + fun deserialize(root: XmlTagReader): Ec2QueryError { var code: String? = null var message: String? = null loop@ while (true) { val curr = root.nextTag() ?: break@loop - when (curr.startTag.name.tag) { + when (curr.tag.name.tag) { "Code" -> code = curr.data() "Message", "message" -> message = curr.data() } diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt index f3cc9e47b..f4385810e 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt @@ -7,10 +7,9 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.awsprotocol.ErrorDetails import aws.smithy.kotlin.runtime.serde.* -import aws.smithy.kotlin.runtime.serde.xml.TagReader +import aws.smithy.kotlin.runtime.serde.xml.XmlTagReader import aws.smithy.kotlin.runtime.serde.xml.data -import aws.smithy.kotlin.runtime.serde.xml.root -import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader /** * Provides access to specific values regardless of message form @@ -35,7 +34,7 @@ internal data class XmlError( */ @InternalApi public fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { - val details = XmlErrorDeserializer.deserialize(xmlStreamReader(payload).root()) + val details = XmlErrorDeserializer.deserialize(xmlTagReader(payload)) return ErrorDetails(details.code, details.message, details.requestId) } @@ -43,26 +42,26 @@ public fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { * This deserializer is used for both wrapped and unwrapped restXml errors. */ internal object XmlErrorDeserializer { - fun deserialize(root: TagReader): XmlError = runCatching { + fun deserialize(root: XmlTagReader): XmlError = runCatching { var message: String? = null var code: String? = null var requestId: String? = null - val rootTagName = root.startTag.name.tag + val rootTagName = root.tag.name.tag check(rootTagName == "ErrorResponse" || rootTagName == "Error") { "expected restXml error response with root tag of or " } // wrapped error, unwrap it var errTag = root - if (root.startTag.name.tag == "ErrorResponse") { + if (root.tag.name.tag == "ErrorResponse") { errTag = root.nextTag() ?: error("expected more tags after ") } - if (errTag.startTag.name.tag == "Error") { + if (errTag.tag.name.tag == "Error") { loop@ while (true) { val curr = errTag.nextTag() ?: break@loop - when (curr.startTag.name.tag) { + when (curr.tag.name.tag) { "Code" -> code = curr.data() "Message", "message" -> message = curr.data() "RequestId" -> requestId = curr.data() @@ -75,7 +74,7 @@ internal object XmlErrorDeserializer { if (requestId == null) { loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.startTag.name.tag) { + when (curr.tag.name.tag) { "RequestId" -> requestId = curr.data() } } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt similarity index 65% rename from runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt rename to runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt index e0caf1fa3..38de566b6 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/TagReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt @@ -9,22 +9,25 @@ import aws.smithy.kotlin.runtime.io.Closeable import aws.smithy.kotlin.runtime.serde.DeserializationException /** - * An [XmlStreamReader] scoped to reading a single XML element [startTag] - * [TagReader] provides a "tag" scoped view into an XML document. Methods return + * An [XmlStreamReader] scoped to reading a single XML element [tag] + * [XmlTagReader] provides a "tag" scoped view into an XML document. Methods return * `null` when the current tag has been exhausted. */ @InternalApi -public class TagReader( - public val startTag: XmlToken.BeginElement, +public class XmlTagReader( + public val tag: XmlToken.BeginElement, private val reader: XmlStreamReader, ) : Closeable { - private var last: TagReader? = null + private var last: XmlTagReader? = null private var closed = false + /** + * Return the next actionable token or null if stream is exhausted. + */ public fun nextToken(): XmlToken? { if (closed) return null val peek = reader.peek() - if (peek.terminates(startTag)) { + if (peek.terminates(tag)) { // consume it and close the tag reader reader.nextToken() closed = true @@ -33,37 +36,31 @@ public class TagReader( return reader.nextToken() } + /** + * Check if the next token has a value, returns false if [XmlToken.EndElement] + * would be returned. + */ public fun nextHasValue(): Boolean { if (closed) return false return reader.peek() !is XmlToken.EndElement } - public fun skipNext() { - if (closed) return - reader.skipNext() - } - - public fun skipCurrent() { - if (closed) return - reader.skipCurrent() - } - override fun close(): Unit = drop() + /** + * Exhaust this [XmlTagReader] to completion. This should always + * be invoked to maintain deserialization state. + */ public fun drop() { do { val tok = nextToken() } while (tok != null) - // // consume the end token for this element - // // FIXME - consuming the next token that ends this messes up the subtree reader state, `nextToken()` will now start - // // to return more tokens - // val next = parent.peek() - // if (next.terminates(startElement)) { - // parent.nextToken() - // } } - public fun nextTag(): TagReader? { + /** + * Return an [XmlTagReader] for the next [XmlToken.BeginElement] + */ + public fun nextTag(): XmlTagReader? { last?.drop() var cand = nextToken() @@ -79,8 +76,15 @@ public class TagReader( } } +/** + * Get a [XmlTagReader] for the root tag. This is the entry point for beginning + * deserialization. + */ @InternalApi -public fun XmlStreamReader.root(): TagReader { +public fun xmlTagReader(payload: ByteArray): XmlTagReader = + xmlStreamReader(payload).root() + +private fun XmlStreamReader.root(): XmlTagReader { val start = seek() ?: error("expected start tag: last = $lastToken") return start.tagReader(this) } @@ -89,26 +93,17 @@ public fun XmlStreamReader.root(): TagReader { * Create a new reader scoped to this element. */ @InternalApi -public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): TagReader { +public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): XmlTagReader { val start = reader.lastToken as? XmlToken.BeginElement ?: error("expected start tag found ${reader.lastToken}") check(name == start.name) { "expected start tag $name but current reader state is on ${start.name}" } - return TagReader(this, reader) + return XmlTagReader(this, reader) } -/** - * Consume the next token and map the data value from it using [transform] - * - * If the next token is not [XmlToken.Text] an exception will be thrown - */ -@InternalApi -public inline fun TagReader.mapData(transform: (String) -> T): T = - transform(data()) - /** * Unwrap the next token as [XmlToken.Text] and return its' value or throw a [DeserializationException] */ @InternalApi -public fun TagReader.data(): String = +public fun XmlTagReader.data(): String = when (val next = nextToken()) { is XmlToken.Text -> next.value ?: "" null, is XmlToken.EndElement -> "" @@ -120,4 +115,4 @@ public fun TagReader.data(): String = * or the exception thrown on failure. */ @InternalApi -public fun TagReader.tryData(): Result = runCatching { data() } +public fun XmlTagReader.tryData(): Result = runCatching { data() } diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt similarity index 84% rename from runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt rename to runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt index ca22315be..3fbc14cb0 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/TagReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt @@ -7,7 +7,7 @@ package aws.smithy.kotlin.runtime.serde.xml import aws.smithy.kotlin.runtime.serde.parseInt import kotlin.test.* -class TagReaderTest { +class XmlTagReaderTest { @Test fun testNextTag() { @@ -24,13 +24,13 @@ class TagReaderTest { more """.encodeToByteArray() - val scoped = xmlStreamReader(payload).root() + val scoped = xmlTagReader(payload) val expected = listOf("a", "b", "c", "d") .map { XmlToken.BeginElement(2, it) } expected.forEach { expectedStartTag -> val tagReader = assertNotNull(scoped.nextTag()) - assertEquals(expectedStartTag, tagReader.startTag) + assertEquals(expectedStartTag, tagReader.tag) tagReader.drop() } } @@ -54,11 +54,11 @@ class TagReaderTest { """.encodeToByteArray() - val scoped = xmlStreamReader(payload).root() - assertEquals(XmlToken.BeginElement(1, "Root"), scoped.startTag) + val scoped = xmlTagReader(payload) + assertEquals(XmlToken.BeginElement(1, "Root"), scoped.tag) val s1 = assertNotNull(scoped.nextTag()) - assertEquals(XmlToken.BeginElement(2, "Child1"), s1.startTag) + assertEquals(XmlToken.BeginElement(2, "Child1"), s1.tag) val s1Elements = listOf( XmlToken.BeginElement(3, "x"), XmlToken.Text(3, "1"), @@ -70,14 +70,14 @@ class TagReaderTest { assertEquals(s1Elements, s1.allTokens()) val s2 = assertNotNull(scoped.nextTag()) - assertEquals(XmlToken.BeginElement(2, "Child2"), s2.startTag) + assertEquals(XmlToken.BeginElement(2, "Child2"), s2.tag) val aReader = assertNotNull(s2.nextTag()) - assertEquals(XmlToken.BeginElement(3, "a"), aReader.startTag) + assertEquals(XmlToken.BeginElement(3, "a"), aReader.tag) assertNull(aReader.nextTag()) val bReader = assertNotNull(s2.nextTag()) - assertEquals(XmlToken.BeginElement(3, "b"), bReader.startTag) + assertEquals(XmlToken.BeginElement(3, "b"), bReader.tag) assertEquals(XmlToken.Text(3, "4"), bReader.nextToken()) assertNull(bReader.nextToken()) bReader.drop() @@ -88,7 +88,7 @@ class TagReaderTest { selfCloseReader.drop() val s4 = assertNotNull(scoped.nextTag()) - assertEquals(XmlToken.BeginElement(2, "Child4"), s4.startTag) + assertEquals(XmlToken.BeginElement(2, "Child4"), s4.tag) } @Test @@ -115,10 +115,10 @@ class TagReaderTest { """.encodeToByteArray() - val decoder = xmlStreamReader(payload).root() + val decoder = xmlTagReader(payload) loop@while (true) { val curr = decoder.nextTag() ?: break@loop - when (curr.startTag.name.tag) { + when (curr.tag.name.tag) { "Child1" -> { assertEquals(1, curr.nextTag()?.data()?.parseInt()?.getOrNull()) assertEquals(2, curr.nextTag()?.data()?.parseInt()?.getOrNull()) @@ -136,7 +136,7 @@ class TagReaderTest { } } -fun TagReader.allTokens(): List { +fun XmlTagReader.allTokens(): List { val tokenList = mutableListOf() var nextToken: XmlToken? do { diff --git a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt index 0f01ee165..4b9b8ee97 100644 --- a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt +++ b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt @@ -7,8 +7,7 @@ package aws.smithy.kotlin.benchmarks.serde.xml import aws.smithy.kotlin.benchmarks.serde.BenchmarkBase import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.model.CountriesAndStates import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.deserializeCountriesAndStatesDocument -import aws.smithy.kotlin.runtime.serde.xml.root -import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader import kotlinx.benchmark.* import kotlinx.coroutines.runBlocking @@ -19,7 +18,7 @@ open class XmlDeserializerBenchmark : BenchmarkBase() { private fun deserialize(): CountriesAndStates = runBlocking { - val deserializer = xmlStreamReader(source).root() + val deserializer = xmlTagReader(source) deserializeCountriesAndStatesDocument(deserializer) } diff --git a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt index dbd6c94bf..6d46e43d7 100644 --- a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt +++ b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt @@ -9,8 +9,7 @@ import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.model.CountriesAnd import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.deserializeCountriesAndStatesDocument import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.serializeCountriesAndStatesDocument import aws.smithy.kotlin.runtime.serde.xml.XmlSerializer -import aws.smithy.kotlin.runtime.serde.xml.root -import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader import kotlinx.benchmark.* import kotlinx.coroutines.runBlocking @@ -22,7 +21,7 @@ open class XmlSerializerBenchmark : BenchmarkBase() { @Setup fun init() { dataSet = runBlocking { - val deserializer = xmlStreamReader(source).root() + val deserializer = xmlTagReader(source) deserializeCountriesAndStatesDocument(deserializer) } } diff --git a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt index 2b0d14461..8a30a7979 100644 --- a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt @@ -20,7 +20,7 @@ object SerdeXmlProtocolGenerator : SerdeProtocolGenerator() { override val protocol: ShapeId = SerdeXmlProtocol.ID override fun structuredDataParser(ctx: ProtocolGenerator.GenerationContext): StructuredDataParserGenerator = - XmlParserGenerator(this, defaultTimestampFormat) + XmlParserGenerator(defaultTimestampFormat) override fun structuredDataSerializer(ctx: ProtocolGenerator.GenerationContext): StructuredDataSerializerGenerator = XmlSerializerGenerator(this, defaultTimestampFormat) diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt index 9ec628331..2e65cb4e5 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt @@ -13,14 +13,14 @@ abstract class AbstractXmlTest { expected: T, payload: String, serializerFn: (XmlSerializer, T) -> Unit, - deserializerFn: (TagReader) -> T, + deserializerFn: (XmlTagReader) -> T, ) { val serializer = XmlSerializer() serializerFn(serializer, expected) val actualPayload = serializer.toByteArray().decodeToString() assertXmlStringsEqual(payload, actualPayload) - val reader = xmlStreamReader(payload.encodeToByteArray()).root() + val reader = xmlTagReader(payload.encodeToByteArray()) val actualDeserialized = deserializerFn(reader) assertEquals(expected, actualDeserialized) } diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt index 8a9f489cd..ca083d198 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt @@ -4,8 +4,7 @@ */ package aws.smithy.kotlin.tests.serde -import aws.smithy.kotlin.runtime.serde.xml.root -import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader import aws.smithy.kotlin.tests.serde.xml.model.StructType import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument @@ -120,7 +119,7 @@ class XmlListTest : AbstractXmlTest() { """.trimIndent() // we don't round trip this because the format isn't going to match - val reader = xmlStreamReader(payload.encodeToByteArray()).root() + val reader = xmlTagReader(payload.encodeToByteArray()) val actualDeserialized = deserializeStructTypeDocument(reader) assertEquals(expected, actualDeserialized) } diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt index c5fcb5a39..9909dbd40 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt @@ -4,8 +4,7 @@ */ package aws.smithy.kotlin.tests.serde -import aws.smithy.kotlin.runtime.serde.xml.root -import aws.smithy.kotlin.runtime.serde.xml.xmlStreamReader +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader import aws.smithy.kotlin.tests.serde.xml.model.FooEnum import aws.smithy.kotlin.tests.serde.xml.model.StructType import aws.smithy.kotlin.tests.serde.xml.model.UnionType @@ -215,7 +214,7 @@ class XmlMapTest : AbstractXmlTest() { """.trimIndent() // we don't round trip this because the format isn't going to match - val reader = xmlStreamReader(payload.encodeToByteArray()).root() + val reader = xmlTagReader(payload.encodeToByteArray()) val actualDeserialized = deserializeStructTypeDocument(reader) assertEquals(expected, actualDeserialized) } From 67d557560dcb8c25a4b1f48b3c9dee46c46e3ef9 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Sat, 24 Feb 2024 23:03:56 -0500 Subject: [PATCH 16/25] reorganize fields for better names --- .../kotlin/codegen/core/RuntimeTypes.kt | 4 +- .../rendering/serde/XmlParserGenerator.kt | 16 ++++---- .../xml/Ec2QueryErrorDeserializer.kt | 8 ++-- .../xml/RestXmlErrorDeserializer.kt | 10 ++--- .../runtime/serde/xml/XmlFieldTraits.kt | 2 +- .../kotlin/runtime/serde/xml/XmlTagReader.kt | 2 +- .../kotlin/runtime/serde/xml/XmlToken.kt | 37 ++++++++++++------- .../kotlin/runtime/serde/xml/dom/XmlNode.kt | 12 +++--- .../runtime/serde/xml/XmlStreamReaderTest.kt | 34 ++++++++--------- .../runtime/serde/xml/XmlTagReaderTest.kt | 2 +- 10 files changed, 69 insertions(+), 58 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index 02f6ce3d9..b382f3de2 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -276,9 +276,9 @@ object RuntimeTypes { val XmlSerializer = symbol("XmlSerializer") val XmlUnwrappedOutput = symbol("XmlUnwrappedOutput") - val TagReader = symbol("TagReader") + val XmlTagReader = symbol("XmlTagReader") val xmlStreamReader = symbol("xmlStreamReader") - val xmlTagReader = symbol("xmlTagReader") + val xmlRootTagReader = symbol("xmlTagReader") val data = symbol("data") val tryData = symbol("tryData") } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index ce11e9469..9422a5995 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -76,7 +76,7 @@ open class XmlParserGenerator( documentMembers: List, writer: KotlinWriter, ) { - writer.write("val root = #T(payload)", SerdeXml.xmlTagReader) + writer.write("val root = #T(payload)", SerdeXml.xmlRootTagReader) val shape = ctx.model.expectShape(op.output.get()) val serdeCtx = unwrapOperationBody(ctx, SerdeCtx("root"), op, writer) @@ -119,7 +119,7 @@ open class XmlParserGenerator( ): Symbol { val symbol = ctx.symbolProvider.toSymbol(shape) return shape.documentDeserializer(ctx.settings, symbol, members) { writer -> - writer.openBlock("internal fun #identifier.name:L(reader: #T): #T {", SerdeXml.TagReader, symbol) + writer.openBlock("internal fun #identifier.name:L(reader: #T): #T {", SerdeXml.XmlTagReader, symbol) .call { val serdeCtx = SerdeCtx("reader") if (shape.isUnionShape) { @@ -148,7 +148,7 @@ open class XmlParserGenerator( val fnName = symbol.errorDeserializerName() writer.openBlock("internal fun #L(builder: #T.Builder, payload: ByteArray) {", fnName, symbol) .call { - writer.write("val root = #T(payload)", SerdeXml.xmlTagReader) + writer.write("val root = #T(payload)", SerdeXml.xmlRootTagReader) val serdeCtx = unwrapOperationError(ctx, SerdeCtx("root"), errorShape, writer) renderDeserializerBody(ctx, serdeCtx, errorShape, members, writer) } @@ -184,7 +184,7 @@ open class XmlParserGenerator( // short circuit when the shape has no modeled members to deserialize write("return #T.Builder().build()", symbol) } else { - writer.write("val root = #T(payload)", SerdeXml.xmlTagReader) + writer.write("val root = #T(payload)", SerdeXml.xmlRootTagReader) write("return #T(root)", deserializeFn) } } @@ -198,7 +198,7 @@ open class XmlParserGenerator( ) { withBlock("loop@while(true) {", "}") { write("val curr = ${serdeCtx.tagReader}.nextTag() ?: break@loop") - withBlock("when(curr.tag.name.tag) {", "}") { + withBlock("when(curr.tag.name) {", "}") { block(this, serdeCtx.copy(tagReader = "curr")) if (ignoreUnexpected) { write("else -> {}") @@ -348,7 +348,7 @@ open class XmlParserGenerator( "internal fun #L(reader: #T): #T {", "}", fnName, - SerdeXml.TagReader, + SerdeXml.XmlTagReader, symbol, ) { block(this) @@ -507,7 +507,7 @@ open class XmlParserGenerator( keySymbol, valueSymbol, nullabilitySuffix(isSparse), - SerdeXml.TagReader, + SerdeXml.XmlTagReader, ) { write("var key: #T? = null", keySymbol) write("var value: #T? = null", valueSymbol) @@ -613,7 +613,7 @@ open class XmlParserGenerator( } val member = members.first() - writer.withBlock("when(${serdeCtx.tagReader}.tag.name.tag) {", "}") { + writer.withBlock("when(${serdeCtx.tagReader}.tag.name) {", "}") { val name = member.getTrait()?.value ?: member.memberName write("// ${member.memberName} ${escape(member.id.toString())}") writeInline("#S -> builder.#L = ", name, member.defaultName()) diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt index bab0795ce..d39134ed3 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt @@ -28,11 +28,11 @@ internal object Ec2QueryErrorResponseDeserializer { fun deserialize(root: XmlTagReader): Ec2QueryErrorResponse = runCatching { var errors: List? = null var requestId: String? = null - if (root.tag.name.tag != "Response") error("expected found ${root.tag}") + if (root.tag.name != "Response") error("expected found ${root.tag}") loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name.tag) { + when (curr.tag.name) { "Errors" -> errors = Ec2QueryErrorListDeserializer.deserialize(curr) "RequestId" -> requestId = curr.data() } @@ -48,7 +48,7 @@ internal object Ec2QueryErrorListDeserializer { val errors = mutableListOf() loop@ while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name.tag) { + when (curr.tag.name) { "Error" -> { val el = Ec2QueryErrorDeserializer.deserialize(curr) errors.add(el) @@ -68,7 +68,7 @@ internal object Ec2QueryErrorDeserializer { loop@ while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name.tag) { + when (curr.tag.name) { "Code" -> code = curr.data() "Message", "message" -> message = curr.data() } diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt index f4385810e..eed58e69e 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt @@ -47,21 +47,21 @@ internal object XmlErrorDeserializer { var code: String? = null var requestId: String? = null - val rootTagName = root.tag.name.tag + val rootTagName = root.tag.name check(rootTagName == "ErrorResponse" || rootTagName == "Error") { "expected restXml error response with root tag of or " } // wrapped error, unwrap it var errTag = root - if (root.tag.name.tag == "ErrorResponse") { + if (root.tag.name == "ErrorResponse") { errTag = root.nextTag() ?: error("expected more tags after ") } - if (errTag.tag.name.tag == "Error") { + if (errTag.tag.name == "Error") { loop@ while (true) { val curr = errTag.nextTag() ?: break@loop - when (curr.tag.name.tag) { + when (curr.tag.name) { "Code" -> code = curr.data() "Message", "message" -> message = curr.data() "RequestId" -> requestId = curr.data() @@ -74,7 +74,7 @@ internal object XmlErrorDeserializer { if (requestId == null) { loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name.tag) { + when (curr.tag.name) { "RequestId" -> requestId = curr.data() } } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt index f14271d22..24c0d1acc 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt @@ -168,7 +168,7 @@ internal fun SdkFieldDescriptor.toQualifiedNames( /** * Determines if the qualified name of this field descriptor matches the given name. */ -internal fun SdkFieldDescriptor.nameMatches(other: String): Boolean = toQualifiedNames().any { it.tag == other } +internal fun SdkFieldDescriptor.nameMatches(other: String): Boolean = toQualifiedNames().any { it.toString() == other } /** * Requires that the given name matches one of this field descriptor's qualified names. diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt index 38de566b6..da486ab53 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt @@ -95,7 +95,7 @@ private fun XmlStreamReader.root(): XmlTagReader { @InternalApi public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): XmlTagReader { val start = reader.lastToken as? XmlToken.BeginElement ?: error("expected start tag found ${reader.lastToken}") - check(name == start.name) { "expected start tag $name but current reader state is on ${start.name}" } + check(qualifiedName == start.qualifiedName) { "expected start tag $qualifiedName but current reader state is on ${start.qualifiedName}" } return XmlTagReader(this, reader) } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt index c265a9617..600a7bf13 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt @@ -28,7 +28,10 @@ public sealed class XmlToken { */ @InternalApi public data class QualifiedName(public val local: String, public val prefix: String? = null) { - override fun toString(): String = tag + override fun toString(): String = when (prefix) { + null -> local + else -> "$prefix:$local" + } @InternalApi public companion object { @@ -45,11 +48,6 @@ public sealed class XmlToken { return QualifiedName(local, prefix) } } - - public val tag: String get() = when (prefix) { - null -> local - else -> "$prefix:$local" - } } /** @@ -58,31 +56,44 @@ public sealed class XmlToken { @InternalApi public data class BeginElement( override val depth: Int, - public val name: QualifiedName, + public val qualifiedName: QualifiedName, public val attributes: Map = emptyMap(), public val nsDeclarations: List = emptyList(), ) : XmlToken() { + // Convenience constructor for name-only nodes. public constructor(depth: Int, name: String) : this(depth, QualifiedName(name)) // Convenience constructor for name-only nodes with attributes. public constructor(depth: Int, name: String, attributes: Map) : this(depth, QualifiedName(name), attributes) - override fun toString(): String = "<${this.name} (${this.depth})>" + override fun toString(): String = "<${this.qualifiedName} (${this.depth})>" // convenience function for codegen public fun getAttr(qualified: String): String? = attributes[QualifiedName.from(qualified)] + + /** + * Get the qualified tag name of this element + */ + val name: String + get() = qualifiedName.toString() } /** * The closing of an XML element */ @InternalApi - public data class EndElement(override val depth: Int, public val name: QualifiedName) : XmlToken() { + public data class EndElement(override val depth: Int, public val qualifiedName: QualifiedName) : XmlToken() { // Convenience constructor for name-only nodes. public constructor(depth: Int, name: String) : this(depth, QualifiedName(name)) - override fun toString(): String = " (${this.depth})" + override fun toString(): String = " (${this.depth})" + + /** + * Get the qualified tag name of this element + */ + val name: String + get() = qualifiedName.toString() } /** @@ -109,8 +120,8 @@ public sealed class XmlToken { } override fun toString(): String = when (this) { - is BeginElement -> "<${this.name}>" - is EndElement -> "" + is BeginElement -> "<${this.qualifiedName}>" + is EndElement -> "" is Text -> "${this.value}" StartDocument -> "[StartDocument]" EndDocument -> "[EndDocument]" @@ -131,7 +142,7 @@ internal fun XmlToken?.terminates(beginToken: XmlToken?): Boolean { if (this !is XmlToken.EndElement) return false if (beginToken !is XmlToken.BeginElement) return false if (depth != beginToken.depth) return false - if (name != beginToken.name) return false + if (qualifiedName != beginToken.qualifiedName) return false return true } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt index 0d1db20e7..52496f89f 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt @@ -46,7 +46,7 @@ public class XmlNode { return parseDom(reader) } - internal fun fromToken(token: XmlToken.BeginElement): XmlNode = XmlNode(token.name).apply { + internal fun fromToken(token: XmlToken.BeginElement): XmlNode = XmlNode(token.qualifiedName).apply { attributes.putAll(token.attributes) namespaces.addAll(token.nsDeclarations) } @@ -83,8 +83,8 @@ public fun parseDom(reader: XmlStreamReader): XmlNode { is XmlToken.EndElement -> { val curr = nodeStack.top() - if (curr.name != token.name) { - throw DeserializationException("expected end of element: `${curr.name}`, found: `${token.name}`") + if (curr.name != token.qualifiedName) { + throw DeserializationException("expected end of element: `${curr.name}`, found: `${token.qualifiedName}`") } if (nodeStack.count() > 1) { @@ -121,7 +121,7 @@ internal fun formatXmlNode(curr: XmlNode, depth: Int, sb: StringBuilder, pretty: // open tag append("$indent<") - append(curr.name.tag) + append(curr.name.toString()) curr.namespaces.forEach { // namespaces declared by this node append(" xmlns") @@ -134,7 +134,7 @@ internal fun formatXmlNode(curr: XmlNode, depth: Int, sb: StringBuilder, pretty: // attributes if (curr.attributes.isNotEmpty()) append(" ") curr.attributes.forEach { - append("${it.key.tag}=\"${it.value}\"") + append("${it.key}=\"${it.value}\"") } append(">") @@ -155,7 +155,7 @@ internal fun formatXmlNode(curr: XmlNode, depth: Int, sb: StringBuilder, pretty: } append("") if (pretty && depth > 0) appendLine() diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt index b4226d6e9..b02969b09 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt @@ -112,7 +112,7 @@ class XmlStreamReaderTest { assertEquals(6, actual.size) assertIs(actual.first()) - assertEquals("payload", (actual.first() as XmlToken.BeginElement).name.local) + assertEquals("payload", (actual.first() as XmlToken.BeginElement).qualifiedName.local) } @Test @@ -231,7 +231,7 @@ class XmlStreamReaderTest { } assertIs(nt) - assertEquals("unknown", nt.name.local) + assertEquals("unknown", nt.qualifiedName.local) if (skipCurrent) { reader.skipCurrent() @@ -240,7 +240,7 @@ class XmlStreamReaderTest { } val y = reader.nextToken() as XmlToken.BeginElement - assertEquals("y", y.name.local) + assertEquals("y", y.qualifiedName.local) } @Test @@ -269,11 +269,11 @@ class XmlStreamReaderTest { assertIs(reader.peek()) val zElement = reader.nextToken() as XmlToken.BeginElement - assertEquals("z", zElement.name.local) + assertEquals("z", zElement.qualifiedName.local) reader.skipNext() val yElement = reader.nextToken() as XmlToken.BeginElement - assertEquals("y", yElement.name.local) + assertEquals("y", yElement.qualifiedName.local) } @Test @@ -312,7 +312,7 @@ class XmlStreamReaderTest { assertNull(reader.lastToken, "Expected to start with null lastToken") var peekedToken = reader.peek() assertIs(peekedToken) - assertEquals("l1", peekedToken.name.local) + assertEquals("l1", peekedToken.qualifiedName.local) assertNull(reader.lastToken, "Expected peek to not effect lastToken") reader.nextToken() // consumed l1 assertEquals(1, reader.lastToken?.depth, "Expected level 1") @@ -320,14 +320,14 @@ class XmlStreamReaderTest { peekedToken = reader.nextToken() // consumed l2 assertEquals(2, reader.lastToken?.depth, "Expected level 2") assertIs(peekedToken) - assertEquals("l2", peekedToken.name.local) + assertEquals("l2", peekedToken.qualifiedName.local) reader.peek() assertEquals(2, reader.lastToken?.depth, "Expected peek to not effect level") peekedToken = reader.nextToken() assertEquals(3, reader.lastToken?.depth, "Expected level 3") assertIs(peekedToken) - assertEquals("l3", peekedToken.name.local) + assertEquals("l3", peekedToken.qualifiedName.local) reader.peek() assertEquals(3, reader.lastToken?.depth, "Expected peek to not effect level") } @@ -459,7 +459,7 @@ class XmlStreamReaderTest { val token = unit.nextToken() assertIs(token) - assertEquals("root", token.name.local) + assertEquals("root", token.qualifiedName.local) var subTree1 = unit.subTreeReader() var subTree1Elements = subTree1.allTokens() @@ -575,25 +575,25 @@ class XmlStreamReaderTest { val rTokenTake = actual.nextToken() assertIs(rTokenPeek) - assertEquals("r", rTokenPeek.name.local) + assertEquals("r", rTokenPeek.qualifiedName.local) assertIs(aToken) - assertEquals("a", aToken.name.local) + assertEquals("a", aToken.qualifiedName.local) assertIs(rTokenTake) - assertEquals("r", rTokenTake.name.local) + assertEquals("r", rTokenTake.qualifiedName.local) val bToken = actual.peek(2) assertIs(bToken) - assertEquals("b", bToken.name.local) + assertEquals("b", bToken.qualifiedName.local) val aTokenTake = actual.nextToken() assertIs(aTokenTake) - assertEquals("a", aTokenTake.name.local) + assertEquals("a", aTokenTake.qualifiedName.local) val aCloseToken = actual.peek(5) // 1: 2: 3: 4: 5: assertIs(aCloseToken) - assertEquals("a", aCloseToken.name.local) + assertEquals("a", aCloseToken.qualifiedName.local) val restOfTokens = actual.allTokens() assertEquals(restOfTokens.size, 6) @@ -621,12 +621,12 @@ class XmlStreamReaderTest { // match begin node of depth 2 val l2Node = unit.seek { it.depth == 2 } assertIs(l2Node) - assertEquals("a", l2Node.name.local) + assertEquals("a", l2Node.qualifiedName.local) // verify next token is correct val nextNode = unit.nextToken() assertIs(nextNode) - assertEquals("b", nextNode.name.local) + assertEquals("b", nextNode.qualifiedName.local) // verify no match produces null unit = xmlStreamReader(payload) diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt index 3fbc14cb0..7200ef892 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt @@ -118,7 +118,7 @@ class XmlTagReaderTest { val decoder = xmlTagReader(payload) loop@while (true) { val curr = decoder.nextTag() ?: break@loop - when (curr.tag.name.tag) { + when (curr.tag.name) { "Child1" -> { assertEquals(1, curr.nextTag()?.data()?.parseInt()?.getOrNull()) assertEquals(2, curr.nextTag()?.data()?.parseInt()?.getOrNull()) From e912826e4611da4aaef3b83e24f99d5b960a093b Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Mon, 26 Feb 2024 08:20:42 -0500 Subject: [PATCH 17/25] api dump + changelog --- .../fb00b4ae-ffdb-4137-baa8-574848296da1.json | 8 ++++ .../api/aws-xml-protocols.api | 4 +- runtime/runtime-core/api/runtime-core.api | 8 ++++ runtime/serde/api/serde.api | 29 +++++++++++++ runtime/serde/serde-xml/api/serde-xml.api | 42 +++++++++++++------ .../kotlin/runtime/serde/xml/XmlSerializer.kt | 2 +- .../runtime/serde/xml/XmlStreamWriter.kt | 2 +- .../runtime/serde/xml/XmlStreamWriterTest.kt | 6 +-- 8 files changed, 81 insertions(+), 20 deletions(-) create mode 100644 .changes/fb00b4ae-ffdb-4137-baa8-574848296da1.json diff --git a/.changes/fb00b4ae-ffdb-4137-baa8-574848296da1.json b/.changes/fb00b4ae-ffdb-4137-baa8-574848296da1.json new file mode 100644 index 000000000..431e3b97f --- /dev/null +++ b/.changes/fb00b4ae-ffdb-4137-baa8-574848296da1.json @@ -0,0 +1,8 @@ +{ + "id": "fb00b4ae-ffdb-4137-baa8-574848296da1", + "type": "bugfix", + "description": "Refactor XML deserialization to handle flat collections", + "issues": [ + "awslabs/aws-sdk-kotlin#1220" + ] +} \ No newline at end of file diff --git a/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api b/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api index d92c49cc8..ab7db70f5 100644 --- a/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api +++ b/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api @@ -1,8 +1,8 @@ public final class aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializerKt { - public static final fun parseEc2QueryErrorResponse ([BLkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun parseEc2QueryErrorResponse ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; } public final class aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializerKt { - public static final fun parseRestXmlErrorResponse ([BLkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun parseRestXmlErrorResponse ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; } diff --git a/runtime/runtime-core/api/runtime-core.api b/runtime/runtime-core/api/runtime-core.api index c4f8afeef..0e37794c0 100644 --- a/runtime/runtime-core/api/runtime-core.api +++ b/runtime/runtime-core/api/runtime-core.api @@ -112,6 +112,10 @@ public final class aws/smithy/kotlin/runtime/collections/AttributesKt { public static final fun toMutableAttributes (Laws/smithy/kotlin/runtime/collections/Attributes;)Laws/smithy/kotlin/runtime/collections/MutableAttributes; } +public final class aws/smithy/kotlin/runtime/collections/CollectionExtKt { + public static final fun createOrAppend (Ljava/util/List;Ljava/lang/Object;)Ljava/util/List; +} + public abstract interface class aws/smithy/kotlin/runtime/collections/MultiMap : java/util/Map, kotlin/jvm/internal/markers/KMappedMarker { public abstract fun contains (Ljava/lang/Object;Ljava/lang/Object;)Z public abstract fun getEntryValues ()Lkotlin/sequences/Sequence; @@ -2252,6 +2256,10 @@ public abstract interface class aws/smithy/kotlin/runtime/util/PropertyProvider public abstract fun getProperty (Ljava/lang/String;)Ljava/lang/String; } +public final class aws/smithy/kotlin/runtime/util/ResultExtKt { + public static final fun mapErr (Ljava/lang/Object;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; +} + public final class aws/smithy/kotlin/runtime/util/SingleFlightGroup { public fun ()V public final fun singleFlight (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/runtime/serde/api/serde.api b/runtime/serde/api/serde.api index c9d0d90b6..8be9fc0e9 100644 --- a/runtime/serde/api/serde.api +++ b/runtime/serde/api/serde.api @@ -39,6 +39,10 @@ public final class aws/smithy/kotlin/runtime/serde/DeserializerKt { public static final fun deserializeStruct (Laws/smithy/kotlin/runtime/serde/Deserializer;Laws/smithy/kotlin/runtime/serde/SdkObjectDescriptor;Lkotlin/jvm/functions/Function1;)V } +public final class aws/smithy/kotlin/runtime/serde/ExceptionsKt { + public static final fun getOrDeserializeErr (Ljava/lang/Object;Lkotlin/jvm/functions/Function0;)Ljava/lang/Object; +} + public abstract interface class aws/smithy/kotlin/runtime/serde/FieldTrait { } @@ -64,6 +68,31 @@ public abstract interface class aws/smithy/kotlin/runtime/serde/MapSerializer : public abstract fun mapEntry (Ljava/lang/String;Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;Lkotlin/jvm/functions/Function1;)V } +public final class aws/smithy/kotlin/runtime/serde/ParsersKt { + public static final fun parse (Ljava/lang/Object;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; + public static final fun parse (Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; + public static final fun parseBigDecimal (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseBigDecimal (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseBigInteger (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseBigInteger (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseBoolean (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseBoolean (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseByte (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseByte (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseDouble (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseDouble (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseFloat (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseFloat (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseInt (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseInt (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseLong (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseLong (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseShort (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseShort (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseTimestamp (Ljava/lang/Object;Laws/smithy/kotlin/runtime/time/TimestampFormat;)Ljava/lang/Object; + public static final fun parseTimestamp (Ljava/lang/String;Laws/smithy/kotlin/runtime/time/TimestampFormat;)Ljava/lang/Object; +} + public abstract interface class aws/smithy/kotlin/runtime/serde/PrimitiveDeserializer { public abstract fun deserializeBigDecimal ()Ljava/math/BigDecimal; public abstract fun deserializeBigInteger ()Ljava/math/BigInteger; diff --git a/runtime/serde/serde-xml/api/serde-xml.api b/runtime/serde/serde-xml/api/serde-xml.api index 846e3d41f..2b7c80a66 100644 --- a/runtime/serde/serde-xml/api/serde-xml.api +++ b/runtime/serde/serde-xml/api/serde-xml.api @@ -47,16 +47,6 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlCollectionValueNamespa public synthetic fun (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V } -public final class aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer : aws/smithy/kotlin/runtime/serde/Deserializer { - public fun (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;Z)V - public synthetic fun (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun ([BZ)V - public synthetic fun ([BZILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun deserializeList (Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$ElementIterator; - public fun deserializeMap (Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$EntryIterator; - public fun deserializeStruct (Laws/smithy/kotlin/runtime/serde/SdkObjectDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$FieldIterator; -} - public final class aws/smithy/kotlin/runtime/serde/xml/XmlError : aws/smithy/kotlin/runtime/serde/FieldTrait { public static final field INSTANCE Laws/smithy/kotlin/runtime/serde/xml/XmlError; public final fun getErrorTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; @@ -153,6 +143,7 @@ public abstract interface class aws/smithy/kotlin/runtime/serde/xml/XmlStreamRea public abstract fun getLastToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public abstract fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public abstract fun peek (I)Laws/smithy/kotlin/runtime/serde/xml/XmlToken; + public abstract fun skipCurrent ()V public abstract fun skipNext ()V public abstract fun subTreeReader (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader$SubtreeStartDepth;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader; } @@ -199,6 +190,23 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterKt { public static synthetic fun xmlStreamWriter$default (ZILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; } +public final class aws/smithy/kotlin/runtime/serde/xml/XmlTagReader : java/io/Closeable { + public fun (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement;Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;)V + public fun close ()V + public final fun drop ()V + public final fun getTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement; + public final fun nextHasValue ()Z + public final fun nextTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader; + public final fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; +} + +public final class aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderKt { + public static final fun data (Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader;)Ljava/lang/String; + public static final fun tagReader (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement;Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;)Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader; + public static final fun tryData (Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader;)Ljava/lang/Object; + public static final fun xmlTagReader ([B)Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader; +} + public abstract class aws/smithy/kotlin/runtime/serde/xml/XmlToken { public abstract fun getDepth ()I public fun toString ()Ljava/lang/String; @@ -216,10 +224,12 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement : a public final fun copy (ILaws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName;Ljava/util/Map;Ljava/util/List;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement; public static synthetic fun copy$default (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement;ILaws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName;Ljava/util/Map;Ljava/util/List;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement; public fun equals (Ljava/lang/Object;)Z + public final fun getAttr (Ljava/lang/String;)Ljava/lang/String; public final fun getAttributes ()Ljava/util/Map; public fun getDepth ()I - public final fun getName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; + public final fun getName ()Ljava/lang/String; public final fun getNsDeclarations ()Ljava/util/List; + public final fun getQualifiedName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -238,7 +248,8 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$EndElement : aws public static synthetic fun copy$default (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$EndElement;ILaws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$EndElement; public fun equals (Ljava/lang/Object;)Z public fun getDepth ()I - public final fun getName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; + public final fun getName ()Ljava/lang/String; + public final fun getQualifiedName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -258,6 +269,7 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$Namespace { } public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName { + public static final field Companion Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName$Companion; public fun (Ljava/lang/String;Ljava/lang/String;)V public synthetic fun (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; @@ -267,11 +279,14 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName { public fun equals (Ljava/lang/Object;)Z public final fun getLocal ()Ljava/lang/String; public final fun getPrefix ()Ljava/lang/String; - public final fun getTag ()Ljava/lang/String; public fun hashCode ()I public fun toString ()Ljava/lang/String; } +public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName$Companion { + public final fun from (Ljava/lang/String;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; +} + public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$StartDocument : aws/smithy/kotlin/runtime/serde/xml/XmlToken { public static final field INSTANCE Laws/smithy/kotlin/runtime/serde/xml/XmlToken$StartDocument; public fun getDepth ()I @@ -299,6 +314,7 @@ public final class aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXml public fun getLastToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public fun peek (I)Laws/smithy/kotlin/runtime/serde/xml/XmlToken; + public fun skipCurrent ()V public fun skipNext ()V public fun subTreeReader (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader$SubtreeStartDepth;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader; } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt index ff50a3d3c..825993143 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlSerializer.kt @@ -184,7 +184,7 @@ public class XmlSerializer(private val xmlWriter: XmlStreamWriter = xmlStreamWri xmlWriter.text(value.toPlainString()) } - private fun serializeNumber(value: Number): Unit = xmlWriter.data(value) + private fun serializeNumber(value: Number): Unit = xmlWriter.text(value) override fun serializeString(value: String) { xmlWriter.text(value) diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt index a4f7b6176..1c529d912 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter.kt @@ -78,7 +78,7 @@ public interface XmlStreamWriter { } @InternalApi -public fun XmlStreamWriter.data(text: Number) { +public fun XmlStreamWriter.text(text: Number) { this.text(text.toString()) } diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt index 38ea083e0..316f7034e 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterTest.kt @@ -167,7 +167,7 @@ fun writeMessage(writer: XmlStreamWriter, message: Message) { writer.apply { startTag("message") startTag("id") - data(message.id) + text(message.id) endTag("id") startTag("text") text(message.text) @@ -190,7 +190,7 @@ fun writeUser(writer: XmlStreamWriter, user: User) { writer.text(user.name) writer.endTag("name") writer.startTag("followers_count") - writer.data(user.followersCount) + writer.text(user.followersCount) writer.endTag("followers_count") writer.endTag("user") } @@ -200,7 +200,7 @@ fun writeDoublesArray(writer: XmlStreamWriter, doubles: Array?) { if (doubles != null) { for (value in doubles) { writer.startTag("position") - writer.data(value) + writer.text(value) writer.endTag("position") } } From 5b7f8cf509e788d6962ed3c14bbbbe882cec7b69 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Mon, 26 Feb 2024 09:07:00 -0500 Subject: [PATCH 18/25] update benchmark baseline --- tests/benchmarks/serde-benchmarks/README.md | 26 ++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/benchmarks/serde-benchmarks/README.md b/tests/benchmarks/serde-benchmarks/README.md index a618a0e35..6cca8e91f 100644 --- a/tests/benchmarks/serde-benchmarks/README.md +++ b/tests/benchmarks/serde-benchmarks/README.md @@ -8,20 +8,20 @@ This project contains micro benchmarks for the serialization implementation(s). ./gradlew :runtime:serde:serde-benchmarks:jvmBenchmark ``` -Baseline `0.7.8-beta` on EC2 **[m5.4xlarge](https://aws.amazon.com/ec2/instance-types/m5/)** in **OpenJK 1.8.0_312**: +Baseline on EC2 **[m5.4xlarge](https://aws.amazon.com/ec2/instance-types/m5/)** in **Corretto-17.0.10.8.1**: ``` jvm summary: -Benchmark (sourceFilename) Mode Cnt Score Error Units -a.s.k.b.s.json.CitmBenchmark.tokensBenchmark N/A avgt 5 12.530 ± 0.611 ms/op -a.s.k.b.s.json.TwitterBenchmark.deserializeBenchmark N/A avgt 5 10.148 ± 7.515 ms/op -a.s.k.b.s.json.TwitterBenchmark.serializeBenchmark N/A avgt 5 1.534 ± 1.608 ms/op -a.s.k.b.s.json.TwitterBenchmark.tokensBenchmark N/A avgt 5 6.381 ± 3.615 ms/op -a.s.k.b.s.xml.BufferStreamWriterBenchmark.serializeBenchmark N/A avgt 5 11.746 ± 0.262 ms/op -a.s.k.b.s.xml.XmlDeserializerBenchmark.deserializeBenchmark N/A avgt 5 90.697 ± 1.178 ms/op -a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark countries-states.xml avgt 5 22.665 ± 0.473 ms/op -a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark kotlin-article.xml avgt 5 0.734 ± 0.017 ms/op -a.s.k.b.s.xml.XmlSerializerBenchmark.serializeBenchmark N/A avgt 5 27.324 ± 31.331 ms/op +Benchmark (sourceFilename) Mode Cnt Score Error Units +a.s.k.b.s.json.CitmBenchmark.tokensBenchmark N/A avgt 5 10.066 ± 0.033 ms/op +a.s.k.b.s.json.TwitterBenchmark.deserializeBenchmark N/A avgt 5 7.295 ± 0.033 ms/op +a.s.k.b.s.json.TwitterBenchmark.serializeBenchmark N/A avgt 5 1.498 ± 0.026 ms/op +a.s.k.b.s.json.TwitterBenchmark.tokensBenchmark N/A avgt 5 4.431 ± 0.029 ms/op +a.s.k.b.s.xml.BufferStreamWriterBenchmark.serializeBenchmark N/A avgt 5 10.540 ± 0.134 ms/op +a.s.k.b.s.xml.XmlDeserializerBenchmark.deserializeBenchmark N/A avgt 5 33.566 ± 0.074 ms/op +a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark countries-states.xml avgt 5 25.200 ± 0.079 ms/op +a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark kotlin-article.xml avgt 5 0.846 ± 0.003 ms/op +a.s.k.b.s.xml.XmlSerializerBenchmark.serializeBenchmark N/A avgt 5 21.714 ± 0.385 ms/op ``` ## JSON Data @@ -44,7 +44,7 @@ Raw data was imported from multiple sources: ## Benchmarks -The `model` folder contains hand rolled Smithy models for some of the benchmarks. The `smithy-benchmarks-codegen` project -contains the codegen support to generate these models. +The `model` folder contains hand rolled Smithy models for some of the benchmarks. +The `tests/codegen/serde-codegen-support` module contains the codegen support to generate these models. These models are generated as part of the build. Until you run `assemble` you may see errors in your IDE. \ No newline at end of file From b78724ada28fcf83e68f9e2ab0d9a3a3ef1e5800 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Mon, 26 Feb 2024 09:19:17 -0500 Subject: [PATCH 19/25] fix -warn --- tests/codegen/serde-tests/build.gradle.kts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/codegen/serde-tests/build.gradle.kts b/tests/codegen/serde-tests/build.gradle.kts index e963e664e..ee2f94ab5 100644 --- a/tests/codegen/serde-tests/build.gradle.kts +++ b/tests/codegen/serde-tests/build.gradle.kts @@ -89,6 +89,8 @@ kotlin.sourceSets.getByName("main") { tasks.withType { dependsOn(stageGeneratedSources) + // generated code has warnings unfortunately, see https://github.com/awslabs/aws-sdk-kotlin/issues/1169 + kotlinOptions.allWarningsAsErrors = false } tasks.clean.configure { From a1fe7c67ec4707de579d6224af188d617b760450 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Mon, 26 Feb 2024 12:25:01 -0500 Subject: [PATCH 20/25] fix member names --- .../codegen/rendering/serde/XmlParserGenerator.kt | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index 9422a5995..9e7b53982 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -248,8 +248,9 @@ open class XmlParserGenerator( writer.deserializeLoop(serdeCtx) { innerCtx -> payloadMembers.forEach { member -> val name = member.getTrait()?.value ?: member.memberName + write("// ${member.memberName} ${escape(member.id.toString())}") - writeInline("#S -> builder.#L = ", name, member.defaultName()) + writeInline("#S -> builder.#L = ", name, ctx.symbolProvider.toMemberName(member)) deserializeMember(ctx, innerCtx, member, writer) } } @@ -267,7 +268,7 @@ open class XmlParserGenerator( "}", memberName, ) { - writeInline("builder.#L = ", member.defaultName()) + writeInline("builder.#L = ", ctx.symbolProvider.toMemberName(member)) deserializePrimitiveMember(ctx, member, "it", textExprIsResult = false, this) } } @@ -383,15 +384,17 @@ open class XmlParserGenerator( private fun flatCollectionAccumulatorExpr( ctx: ProtocolGenerator.GenerationContext, member: MemberShape, - ): String = - when (val container = ctx.model.expectShape(member.container)) { - is StructureShape -> "builder.${member.defaultName()}" + ): String { + val escapedMemberName = ctx.symbolProvider.toMemberName(member) + return when (val container = ctx.model.expectShape(member.container)) { + is StructureShape -> "builder.$escapedMemberName" is UnionShape -> { val unionVariantName = member.unionVariantName() "value?.as${unionVariantName}OrNull()" } else -> error("unexpected container shape $container for member $member") } + } private fun deserializeFlatList( ctx: ProtocolGenerator.GenerationContext, @@ -616,7 +619,7 @@ open class XmlParserGenerator( writer.withBlock("when(${serdeCtx.tagReader}.tag.name) {", "}") { val name = member.getTrait()?.value ?: member.memberName write("// ${member.memberName} ${escape(member.id.toString())}") - writeInline("#S -> builder.#L = ", name, member.defaultName()) + writeInline("#S -> builder.#L = ", name, ctx.symbolProvider.toMemberName(member)) deserializeMember(ctx, serdeCtx, member, writer) } } From 4a98f5e24e2b01837122c4126a70cf4826ea391a Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Tue, 27 Feb 2024 12:33:39 -0500 Subject: [PATCH 21/25] feedback --- .../codegen/rendering/serde/SerdeExt.kt | 2 +- .../rendering/serde/XmlParserGenerator.kt | 105 +++-- .../xml/Ec2QueryErrorDeserializer.kt | 14 +- .../xml/RestXmlErrorDeserializer.kt | 14 +- .../smithy/kotlin/runtime/util/ResultExt.kt | 3 + .../smithy/kotlin/runtime/serde/Parsers.kt | 1 + .../runtime/serde/xml/XmlDeserializer.kt | 416 ++++++++++++++++++ .../serde/xml/XmlPrimitiveDeserializer.kt | 75 ++++ .../runtime/serde/xml/XmlStreamReader.kt | 5 - .../kotlin/runtime/serde/xml/XmlTagReader.kt | 27 +- .../kotlin/runtime/serde/xml/XmlToken.kt | 33 +- .../deserialization/LexingXmlStreamReader.kt | 8 - .../kotlin/runtime/serde/xml/dom/XmlNode.kt | 6 +- .../runtime/serde/xml/XmlStreamReaderTest.kt | 53 +-- .../runtime/serde/xml/XmlTagReaderTest.kt | 2 +- 15 files changed, 625 insertions(+), 139 deletions(-) create mode 100644 runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt create mode 100644 runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt index 036cb3d79..071d8ba36 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt @@ -237,7 +237,7 @@ fun TimestampFormatTrait.Format.toRuntimeEnum(writer: KotlinWriter): String { TimestampFormatTrait.Format.EPOCH_SECONDS -> "EPOCH_SECONDS" TimestampFormatTrait.Format.DATE_TIME -> "ISO_8601" TimestampFormatTrait.Format.HTTP_DATE -> "RFC_5322" - TimestampFormatTrait.Format.UNKNOWN -> error("unknown timestamp format trait") + else -> throw CodegenException("unknown timestamp format: $this") } return writer.format("#T.#L", RuntimeTypes.Core.TimestampFormat, enum) } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index 9e7b53982..1e1bcb3e2 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -8,11 +8,11 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.kotlin.codegen.core.* -import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes.Serde -import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes.Serde.SerdeXml +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex +import software.amazon.smithy.kotlin.codegen.model.traits.SyntheticClone import software.amazon.smithy.kotlin.codegen.model.traits.UnwrappedXmlOutput import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator import software.amazon.smithy.model.shapes.* @@ -22,6 +22,7 @@ import software.amazon.smithy.model.traits.XmlAttributeTrait import software.amazon.smithy.model.traits.XmlFlattenedTrait import software.amazon.smithy.model.traits.XmlNameTrait import software.amazon.smithy.utils.StringUtils +import kotlin.jvm.optionals.getOrDefault /** * XML parser generator based on common deserializer interface and XML serde descriptors @@ -76,7 +77,7 @@ open class XmlParserGenerator( documentMembers: List, writer: KotlinWriter, ) { - writer.write("val root = #T(payload)", SerdeXml.xmlRootTagReader) + writer.write("val root = #T(payload)", RuntimeTypes.Serde.SerdeXml.xmlRootTagReader) val shape = ctx.model.expectShape(op.output.get()) val serdeCtx = unwrapOperationBody(ctx, SerdeCtx("root"), op, writer) @@ -119,13 +120,13 @@ open class XmlParserGenerator( ): Symbol { val symbol = ctx.symbolProvider.toSymbol(shape) return shape.documentDeserializer(ctx.settings, symbol, members) { writer -> - writer.openBlock("internal fun #identifier.name:L(reader: #T): #T {", SerdeXml.XmlTagReader, symbol) + writer.openBlock("internal fun #identifier.name:L(reader: #T): #T {", RuntimeTypes.Serde.SerdeXml.XmlTagReader, symbol) .call { val serdeCtx = SerdeCtx("reader") if (shape.isUnionShape) { writer.write("var value: #T? = null", symbol) renderDeserializerBody(ctx, serdeCtx, shape, members.toList(), writer) - writer.write("return value ?: throw #T(#S)", Serde.DeserializationException, "Deserialized union value unexpectedly null: ${symbol.name}") + writer.write("return value ?: throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "Deserialized union value unexpectedly null: ${symbol.name}") } else { writer.write("val builder = #T.Builder()", symbol) renderDeserializerBody(ctx, serdeCtx, shape, members.toList(), writer) @@ -148,7 +149,7 @@ open class XmlParserGenerator( val fnName = symbol.errorDeserializerName() writer.openBlock("internal fun #L(builder: #T.Builder, payload: ByteArray) {", fnName, symbol) .call { - writer.write("val root = #T(payload)", SerdeXml.xmlRootTagReader) + writer.write("val root = #T(payload)", RuntimeTypes.Serde.SerdeXml.xmlRootTagReader) val serdeCtx = unwrapOperationError(ctx, SerdeCtx("root"), errorShape, writer) renderDeserializerBody(ctx, serdeCtx, errorShape, members, writer) } @@ -184,7 +185,7 @@ open class XmlParserGenerator( // short circuit when the shape has no modeled members to deserialize write("return #T.Builder().build()", symbol) } else { - writer.write("val root = #T(payload)", SerdeXml.xmlRootTagReader) + writer.write("val root = #T(payload)", RuntimeTypes.Serde.SerdeXml.xmlRootTagReader) write("return #T(root)", deserializeFn) } } @@ -196,12 +197,14 @@ open class XmlParserGenerator( ignoreUnexpected: Boolean = true, block: KotlinWriter.(SerdeCtx) -> Unit, ) { - withBlock("loop@while(true) {", "}") { - write("val curr = ${serdeCtx.tagReader}.nextTag() ?: break@loop") - withBlock("when(curr.tag.name) {", "}") { + withBlock("loop@while (true) {", "}") { + write("val curr = #L.nextTag() ?: break@loop", serdeCtx.tagReader) + withBlock("when (curr.tagName) {", "}") { block(this, serdeCtx.copy(tagReader = "curr")) if (ignoreUnexpected) { write("else -> {}") + } else { + write("else -> throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "Unexpected tag \${curr.tag}") } } // maintain stream reader state by dropping the current element and all it's children @@ -211,6 +214,20 @@ open class XmlParserGenerator( } } + private fun originalMember(ctx: ProtocolGenerator.GenerationContext, member: MemberShape): MemberShape { + val containerShapeId = ctx.model.expectShape(member.container).getTrait()?.archetype ?: member.container + val container = ctx.model.expectShape(containerShapeId) + return container.getMember(member.memberName).getOrDefault(member) + } + + private fun KotlinWriter.writeMemberDebugComment( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + ) { + val originalMember = originalMember(ctx, member) + write("// ${originalMember.memberName} ${escape(originalMember.id.toString())}") + } + private fun deserializeUnion( ctx: ProtocolGenerator.GenerationContext, serdeCtx: SerdeCtx, @@ -220,7 +237,7 @@ open class XmlParserGenerator( writer.deserializeLoop(serdeCtx) { innerCtx -> members.forEach { member -> val name = member.getTrait()?.value ?: member.memberName - write("// ${member.memberName} ${escape(member.id.toString())}") + writeMemberDebugComment(ctx, member) val unionTypeName = member.unionTypeName(ctx) withBlock("#S -> value = #L(", ")", name, unionTypeName) { deserializeMember(ctx, innerCtx, member, writer) @@ -236,12 +253,14 @@ open class XmlParserGenerator( writer: KotlinWriter, ) { // split attribute members and non attribute members - val attributeMembers = members.filter { it.hasTrait() } + val (attributeMembers, payloadMembers) = members.partition { + it.hasTrait() + } + attributeMembers.forEach { member -> deserializeAttributeMember(ctx, serdeCtx, member, writer) } - val payloadMembers = members.filterNot { it.hasTrait() } // don't generate a parse loop if no attribute members if (payloadMembers.isEmpty()) return writer.write("") @@ -264,8 +283,9 @@ open class XmlParserGenerator( ) { val memberName = member.getTrait()?.value ?: member.memberName writer.withBlock( - "${serdeCtx.tagReader}.tag.getAttr(#S)?.let {", + "#L.tag.getAttr(#S)?.let {", "}", + serdeCtx.tagReader, memberName, ) { writeInline("builder.#L = ", ctx.symbolProvider.toMemberName(member)) @@ -297,12 +317,12 @@ open class XmlParserGenerator( } ShapeType.STRUCTURE, ShapeType.UNION -> { val deserializeFn = documentDeserializer(ctx, target) - writer.write("#T(${serdeCtx.tagReader})", deserializeFn) + writer.write("#T(#L)", deserializeFn, serdeCtx.tagReader) } else -> deserializePrimitiveMember( ctx, member, - writer.format("${serdeCtx.tagReader}.#T()", SerdeXml.tryData), + writer.format("#L.#T()", serdeCtx.tagReader, RuntimeTypes.Serde.SerdeXml.tryData), textExprIsResult = true, writer, ) @@ -318,6 +338,7 @@ open class XmlParserGenerator( val shapeName = StringUtils.capitalize(target.id.getName(ctx.service)) return "${shapeName}ShapeDeserializer.kt" } + private fun Shape.shapeDeserializer( ctx: ProtocolGenerator.GenerationContext, block: (fnName: String, writer: KotlinWriter) -> Unit, @@ -349,7 +370,7 @@ open class XmlParserGenerator( "internal fun #L(reader: #T): #T {", "}", fnName, - SerdeXml.XmlTagReader, + RuntimeTypes.Serde.SerdeXml.XmlTagReader, symbol, ) { block(this) @@ -378,7 +399,7 @@ open class XmlParserGenerator( } write("return result") } - writer.write("#T(${serdeCtx.tagReader})", deserializeFn) + writer.write("#T(#L)", deserializeFn, serdeCtx.tagReader) } private fun flatCollectionAccumulatorExpr( @@ -421,7 +442,7 @@ open class XmlParserGenerator( val isSparse = target.hasTrait() with(writer) { if (isSparse) { - openBlock("val el = if (${serdeCtx.tagReader}.nextHasValue()) {") + openBlock("val el = if (#L.nextHasValue()) {", serdeCtx.tagReader) .call { deserializeMember(ctx, serdeCtx, target.member, this) } @@ -457,8 +478,9 @@ open class XmlParserGenerator( } write("return result") } - writer.write("#T(${serdeCtx.tagReader})", deserializeFn) + writer.write("#T(#L)", deserializeFn, serdeCtx.tagReader) } + private fun deserializeFlatMap( ctx: ProtocolGenerator.GenerationContext, serdeCtx: SerdeCtx, @@ -480,7 +502,7 @@ open class XmlParserGenerator( nullabilitySuffix(isSparse), ) val deserializeEntryFn = deserializeMapEntry(ctx, target) - write("#T(dest, ${serdeCtx.tagReader})", deserializeEntryFn) + write("#T(dest, #L)", deserializeEntryFn, serdeCtx.tagReader) write("dest") } } @@ -510,7 +532,7 @@ open class XmlParserGenerator( keySymbol, valueSymbol, nullabilitySuffix(isSparse), - SerdeXml.XmlTagReader, + RuntimeTypes.Serde.SerdeXml.XmlTagReader, ) { write("var key: #T? = null", keySymbol) write("var value: #T? = null", valueSymbol) @@ -534,9 +556,9 @@ open class XmlParserGenerator( deserializeMember(ctx, innerCtx, map.value, this) } } - write("if (key == null) throw #T(#S)", Serde.DeserializationException, "missing key map entry") + write("if (key == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing key map entry") if (!isSparse) { - write("if (value == null) throw #T(#S)", Serde.DeserializationException, "missing value map entry") + write("if (value == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing value map entry") } write("dest[key] = value") } @@ -554,8 +576,8 @@ open class XmlParserGenerator( val target = ctx.model.expectShape(member.target) val parseFn = when (target.type) { - ShapeType.BLOB -> writer.format("#T { it.#T() } ", Serde.parse, RuntimeTypes.Core.Text.Encoding.decodeBase64Bytes) - ShapeType.BOOLEAN -> writer.format("#T()", Serde.parseBoolean) + ShapeType.BLOB -> writer.format("#T { it.#T() } ", RuntimeTypes.Serde.parse, RuntimeTypes.Core.Text.Encoding.decodeBase64Bytes) + ShapeType.BOOLEAN -> writer.format("#T()", RuntimeTypes.Serde.parseBoolean) ShapeType.STRING -> { if (!textExprIsResult) { writer.write(textExpr) @@ -567,29 +589,26 @@ open class XmlParserGenerator( ShapeType.TIMESTAMP -> { val trait = member.getTrait() ?: target.getTrait() val tsFormat = trait?.format ?: defaultTimestampFormat - // val fromArg = writer.format("curr.#T()") - // val fmtExpr = writer.parseInstantExpr(fromArg, tsFormat) - // writer.write(fmtExpr) val runtimeEnum = tsFormat.toRuntimeEnum(writer) - writer.format("#T(#L)", Serde.parseTimestamp, runtimeEnum) + writer.format("#T(#L)", RuntimeTypes.Serde.parseTimestamp, runtimeEnum) } - ShapeType.BYTE -> writer.format("#T()", Serde.parseByte) - ShapeType.SHORT -> writer.format("#T()", Serde.parseShort) - ShapeType.INTEGER -> writer.format("#T()", Serde.parseInt) - ShapeType.LONG -> writer.format("#T()", Serde.parseLong) - ShapeType.FLOAT -> writer.format("#T()", Serde.parseFloat) - ShapeType.DOUBLE -> writer.format("#T()", Serde.parseDouble) - ShapeType.BIG_DECIMAL -> writer.format("#T()", Serde.parseBigDecimal) - ShapeType.BIG_INTEGER -> writer.format("#T()", Serde.parseBigInteger) + ShapeType.BYTE -> writer.format("#T()", RuntimeTypes.Serde.parseByte) + ShapeType.SHORT -> writer.format("#T()", RuntimeTypes.Serde.parseShort) + ShapeType.INTEGER -> writer.format("#T()", RuntimeTypes.Serde.parseInt) + ShapeType.LONG -> writer.format("#T()", RuntimeTypes.Serde.parseLong) + ShapeType.FLOAT -> writer.format("#T()", RuntimeTypes.Serde.parseFloat) + ShapeType.DOUBLE -> writer.format("#T()", RuntimeTypes.Serde.parseDouble) + ShapeType.BIG_DECIMAL -> writer.format("#T()", RuntimeTypes.Serde.parseBigDecimal) + ShapeType.BIG_INTEGER -> writer.format("#T()", RuntimeTypes.Serde.parseBigInteger) ShapeType.ENUM -> { if (!textExprIsResult) { writer.write("#T.fromValue(#L)", ctx.symbolProvider.toSymbol(target), textExpr) return } - writer.format("#T { #T.fromValue(it) } ", Serde.parse, ctx.symbolProvider.toSymbol(target)) + writer.format("#T { #T.fromValue(it) } ", RuntimeTypes.Serde.parse, ctx.symbolProvider.toSymbol(target)) } ShapeType.INT_ENUM -> { - writer.format("#T { #T.fromValue(it.toInt()) } ", Serde.parse, ctx.symbolProvider.toSymbol(target)) + writer.format("#T { #T.fromValue(it.toInt()) } ", RuntimeTypes.Serde.parse, ctx.symbolProvider.toSymbol(target)) } else -> error("unknown primitive member shape $member") } @@ -600,7 +619,7 @@ open class XmlParserGenerator( .callIf(parseFn != null) { writer.write(".#L", parseFn) } - .write(".#T { #S }", Serde.getOrDeserializeErr, escapedErrMessage) + .write(".#T { #S }", RuntimeTypes.Serde.getOrDeserializeErr, escapedErrMessage) .dedent() } @@ -616,9 +635,9 @@ open class XmlParserGenerator( } val member = members.first() - writer.withBlock("when(${serdeCtx.tagReader}.tag.name) {", "}") { + writer.withBlock("when (#L.tagName) {", "}", serdeCtx.tagReader) { val name = member.getTrait()?.value ?: member.memberName - write("// ${member.memberName} ${escape(member.id.toString())}") + writeMemberDebugComment(ctx, member) writeInline("#S -> builder.#L = ", name, ctx.symbolProvider.toMemberName(member)) deserializeMember(ctx, serdeCtx, member, writer) } diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt index d39134ed3..4ce0e7c98 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt @@ -14,7 +14,7 @@ internal data class Ec2QueryErrorResponse(val errors: List, val r internal data class Ec2QueryError(val code: String?, val message: String?) @InternalApi -public fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { +public suspend fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { val response = Ec2QueryErrorResponseDeserializer.deserialize(xmlTagReader(payload)) val firstError = response.errors.firstOrNull() return ErrorDetails(firstError?.code, firstError?.message, response.requestId) @@ -28,11 +28,11 @@ internal object Ec2QueryErrorResponseDeserializer { fun deserialize(root: XmlTagReader): Ec2QueryErrorResponse = runCatching { var errors: List? = null var requestId: String? = null - if (root.tag.name != "Response") error("expected found ${root.tag}") + if (root.tagName != "Response") error("expected found ${root.tag}") loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name) { + when (curr.tagName) { "Errors" -> errors = Ec2QueryErrorListDeserializer.deserialize(curr) "RequestId" -> requestId = curr.data() } @@ -46,9 +46,9 @@ internal object Ec2QueryErrorResponseDeserializer { internal object Ec2QueryErrorListDeserializer { fun deserialize(root: XmlTagReader): List { val errors = mutableListOf() - loop@ while (true) { + loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name) { + when (curr.tagName) { "Error" -> { val el = Ec2QueryErrorDeserializer.deserialize(curr) errors.add(el) @@ -66,9 +66,9 @@ internal object Ec2QueryErrorDeserializer { var code: String? = null var message: String? = null - loop@ while (true) { + loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name) { + when (curr.tagName) { "Code" -> code = curr.data() "Message", "message" -> message = curr.data() } diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt index eed58e69e..60e494425 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt @@ -33,7 +33,7 @@ internal data class XmlError( * Returns parsed data in normalized form or throws [DeserializationException] if response cannot be parsed. */ @InternalApi -public fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { +public suspend fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { val details = XmlErrorDeserializer.deserialize(xmlTagReader(payload)) return ErrorDetails(details.code, details.message, details.requestId) } @@ -47,21 +47,21 @@ internal object XmlErrorDeserializer { var code: String? = null var requestId: String? = null - val rootTagName = root.tag.name + val rootTagName = root.tagName check(rootTagName == "ErrorResponse" || rootTagName == "Error") { "expected restXml error response with root tag of or " } // wrapped error, unwrap it var errTag = root - if (root.tag.name == "ErrorResponse") { + if (root.tagName == "ErrorResponse") { errTag = root.nextTag() ?: error("expected more tags after ") } - if (errTag.tag.name == "Error") { - loop@ while (true) { + if (errTag.tagName == "Error") { + loop@while (true) { val curr = errTag.nextTag() ?: break@loop - when (curr.tag.name) { + when (curr.tagName) { "Code" -> code = curr.data() "Message", "message" -> message = curr.data() "RequestId" -> requestId = curr.data() @@ -74,7 +74,7 @@ internal object XmlErrorDeserializer { if (requestId == null) { loop@while (true) { val curr = root.nextTag() ?: break@loop - when (curr.tag.name) { + when (curr.tagName) { "RequestId" -> requestId = curr.data() } } diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt index 6ba8947d3..b2ed33b0d 100644 --- a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt @@ -4,10 +4,13 @@ */ package aws.smithy.kotlin.runtime.util +import aws.smithy.kotlin.runtime.InternalApi + /** * Maps the exception to a new error if this instance represents [failure][Result.isFailure], leaving * a [success][Result.isSuccess] value untouched. */ +@InternalApi public inline fun Result.mapErr(onFailure: (Throwable) -> Throwable): Result = when (val ex = exceptionOrNull()) { null -> this diff --git a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt index 7c86d51ea..3991f0f1a 100644 --- a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt +++ b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt @@ -34,6 +34,7 @@ public fun String.parseDouble(): Result = parse(String::toDouble) @InternalApi public fun String.parseByte(): Result = parse { it.toInt().toByte() } +@InternalApi public fun String.parseBigInteger(): Result = parse(::BigInteger) @InternalApi diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt new file mode 100644 index 000000000..e6dd81942 --- /dev/null +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt @@ -0,0 +1,416 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.serde.xml + +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.content.BigDecimal +import aws.smithy.kotlin.runtime.content.BigInteger +import aws.smithy.kotlin.runtime.content.Document +import aws.smithy.kotlin.runtime.serde.* + +private const val FIRST_FIELD_INDEX: Int = 0 + +// Represents aspects of SdkFieldDescriptor that are particular to the Xml format +internal sealed class FieldLocation { + // specifies the mapping to a sdk field index + abstract val fieldIndex: Int + + data class Text(override val fieldIndex: Int) : FieldLocation() // Xml nodes have only one associated Text element + data class Attribute(override val fieldIndex: Int, val names: Set) : FieldLocation() +} + +/** + * Provides a deserializer for XML documents + * + * @param reader underlying [XmlStreamReader] from which tokens are read + * @param validateRootElement Flag indicating if the root XML document [XmlToken.BeginElement] should be validated against + * the descriptor passed to [deserializeStruct]. This only affects the root element, not nested struct elements. Some + * restXml based services DO NOT always send documents with a root element name that matches the shape ID name + * (S3 in particular). This means there is nothing in the model that gives you enough information to validate the tag. + */ +@Deprecated("XmlDeserializer is deprecated and will be removed in a future release") +@InternalApi +public class XmlDeserializer( + private val reader: XmlStreamReader, + private val validateRootElement: Boolean = false, +) : Deserializer { + + public constructor(input: ByteArray, validateRootElement: Boolean = false) : this(xmlStreamReader(input), validateRootElement) + + private var firstStructCall = true + + override fun deserializeStruct(descriptor: SdkObjectDescriptor): Deserializer.FieldIterator { + if (firstStructCall) { + if (!descriptor.hasTrait()) throw DeserializationException("Top-level struct $descriptor requires a XmlSerialName trait but has none.") + + firstStructCall = false + + reader.nextToken() // Matching field descriptors to children tags so consume the start element of top-level struct + + val structToken = if (descriptor.hasTrait()) { + reader.seek { it.name == descriptor.expectTrait().errorTag } + } else { + reader.seek() + } ?: throw DeserializationException("Could not find a begin element for new struct") + + if (validateRootElement) { + descriptor.requireNameMatch(structToken.name.tag) + } + } + + // Consume any remaining terminating tokens from previous deserialization + reader.seek() + + // Because attributes set on the root node of the struct, we must read the values before creating the subtree + val attribFields = reader.tokenAttributesToFieldLocations(descriptor) + val parentToken = if (reader.lastToken is XmlToken.BeginElement) { + reader.lastToken as XmlToken.BeginElement + } else { + throw DeserializationException("Expected last parsed token to be ${XmlToken.BeginElement::class} but was ${reader.lastToken}") + } + + val unwrapped = descriptor.hasTrait() + return XmlStructDeserializer(descriptor, reader.subTreeReader(XmlStreamReader.SubtreeStartDepth.CURRENT), parentToken, attribFields, unwrapped) + } + + override fun deserializeList(descriptor: SdkFieldDescriptor): Deserializer.ElementIterator { + val depth = when (descriptor.hasTrait()) { + true -> XmlStreamReader.SubtreeStartDepth.CURRENT + else -> XmlStreamReader.SubtreeStartDepth.CHILD + } + + return XmlListDeserializer(reader.subTreeReader(depth), descriptor) + } + + override fun deserializeMap(descriptor: SdkFieldDescriptor): Deserializer.EntryIterator { + val depth = when (descriptor.hasTrait()) { + true -> XmlStreamReader.SubtreeStartDepth.CURRENT + else -> XmlStreamReader.SubtreeStartDepth.CHILD + } + + return XmlMapDeserializer(reader.subTreeReader(depth), descriptor) + } +} + +/** + * Deserializes specific XML structures into forms that can produce Maps + * + * @param reader underlying [XmlStreamReader] from which tokens are read + * @param descriptor associated [SdkFieldDescriptor] which represents the expected Map + * @param primitiveDeserializer used to deserialize primitive values + */ +internal class XmlMapDeserializer( + private val reader: XmlStreamReader, + private val descriptor: SdkFieldDescriptor, + private val primitiveDeserializer: PrimitiveDeserializer = XmlPrimitiveDeserializer(reader, descriptor), +) : PrimitiveDeserializer by primitiveDeserializer, Deserializer.EntryIterator { + private val mapTrait = descriptor.findTrait() ?: XmlMapName.Default + + override fun hasNextEntry(): Boolean { + val compareTo = when (descriptor.hasTrait()) { + true -> descriptor.findTrait()?.name ?: mapTrait.key // Prefer seeking to XmlSerialName if the trait exists + false -> mapTrait.entry + } + + // Seek to either the XML serial name, entry, or key token depending on the flatness of the map and if the name trait is present + val nextEntryToken = when (descriptor.hasTrait()) { + true -> reader.peekSeek { it.name.local == compareTo } + false -> reader.seek { it.name.local == compareTo } + } + + return nextEntryToken != null + } + + override fun key(): String { + // Seek to the key begin token + reader.seek { it.name.local == mapTrait.key } + ?: error("Unable to find key $mapTrait.key in $descriptor") + + val keyValueToken = reader.takeNextAs() + reader.nextToken() // Consume the end wrapper + + return keyValueToken.value ?: throw DeserializationException("Key unspecified in $descriptor") + } + + override fun nextHasValue(): Boolean { + // Expect a begin and value (or another begin) token if Map entry has a value + val peekBeginToken = reader.peek(1) ?: throw DeserializationException("Unexpected termination of token stream in $descriptor") + val peekValueToken = reader.peek(2) ?: throw DeserializationException("Unexpected termination of token stream in $descriptor") + + return peekBeginToken !is XmlToken.EndElement && peekValueToken !is XmlToken.EndElement + } +} + +/** + * Deserializes specific XML structures into forms that can produce Lists + * + * @param reader underlying [XmlStreamReader] from which tokens are read + * @param descriptor associated [SdkFieldDescriptor] which represents the expected Map + * @param primitiveDeserializer used to deserialize primitive values + */ +internal class XmlListDeserializer( + private val reader: XmlStreamReader, + private val descriptor: SdkFieldDescriptor, + private val primitiveDeserializer: PrimitiveDeserializer = XmlPrimitiveDeserializer(reader, descriptor), +) : PrimitiveDeserializer by primitiveDeserializer, Deserializer.ElementIterator { + private var firstCall = true + private val flattened = descriptor.hasTrait() + private val elementName = (descriptor.findTrait() ?: XmlCollectionName.Default).element + + override fun hasNextElement(): Boolean { + if (!flattened && firstCall) { + val nextToken = reader.peek() + val matchedListDescriptor = nextToken is XmlToken.BeginElement && descriptor.nameMatches(nextToken.name.tag) + val hasChildren = if (nextToken == null) false else nextToken.depth >= reader.lastToken!!.depth + + if (!matchedListDescriptor && !hasChildren) return false + + // Discard the wrapper and move to the first element in the list + if (matchedListDescriptor) reader.nextToken() + + firstCall = false + } + + if (flattened) { + // Because our subtree is not CHILD, we cannot rely on the subtree boundary to determine end of collection. + // Rather, we search for either the next begin token matching the (flat) list member name which should + // be immediately after the current token + + // peek at the next token if there is one, in the case of a list of structs, the next token is actually + // the end of the current flat list element in which case we need to peek twice + val next = when (val peeked = reader.peek()) { + is XmlToken.EndElement -> { + if (peeked.name.local == descriptor.serialName.name) { + // consume the end token + reader.nextToken() + reader.peek() + } else { + peeked + } + } + else -> peeked + } + + val tokens = listOfNotNull(reader.lastToken, next) + + // Iterate over the token stream until begin token matching name is found or end element matching list is found. + return tokens + .filterIsInstance() + .any { it.name.local == descriptor.serialName.name } + } else { + // If we can find another begin token w/ the element name, we have more elements to process + return reader.seek { it.name.local == elementName }.isNotTerminal() + } + } + + override fun nextHasValue(): Boolean = reader.peek() !is XmlToken.EndElement +} + +/** + * Deserializes specific XML structures into forms that can produce structures + * + * @param objDescriptor associated [SdkObjectDescriptor] which represents the expected structure + * @param reader underlying [XmlStreamReader] from which tokens are read + * @param parentToken initial token of associated structure + * @param parsedFieldLocations list of [FieldLocation] representing values able to be loaded into deserialized instances + */ +private class XmlStructDeserializer( + private val objDescriptor: SdkObjectDescriptor, + reader: XmlStreamReader, + private val parentToken: XmlToken.BeginElement, + private val parsedFieldLocations: MutableList = mutableListOf(), + private val unwrapped: Boolean, +) : Deserializer.FieldIterator { + // Used to track direct deserialization or further nesting between calls to findNextFieldIndex() and deserialize() + private var reentryFlag: Boolean = false + + private val reader: XmlStreamReader = if (unwrapped) reader else reader.subTreeReader(XmlStreamReader.SubtreeStartDepth.CHILD) + + override fun findNextFieldIndex(): Int? { + if (unwrapped) { + return if (reader.peek() is XmlToken.Text) FIRST_FIELD_INDEX else null + } + if (inNestedMode()) { + // Returning from a nested struct call. Nested deserializer consumed + // tokens so clear them here to avoid processing stale state + parsedFieldLocations.clear() + } + + if (parsedFieldLocations.isEmpty()) { + val matchedFieldLocations = when (val token = reader.nextToken()) { + null, is XmlToken.EndDocument -> return null + is XmlToken.EndElement -> return findNextFieldIndex() + is XmlToken.BeginElement -> { + val nextToken = reader.peek() ?: return null + val objectFields = objDescriptor.fields + val memberFields = objectFields.filter { field -> objDescriptor.fieldTokenMatcher(field, token) } + val matchingFields = memberFields.mapNotNull { it.findFieldLocation(token, nextToken) } + matchingFields + } + else -> return findNextFieldIndex() + } + + // Sorting ensures attribs are processed before text, as processing the Text token pushes the parser on to the next token. + parsedFieldLocations.addAll(matchedFieldLocations.sortedBy { it is FieldLocation.Text }) + } + + return parsedFieldLocations.firstOrNull()?.fieldIndex ?: Deserializer.FieldIterator.UNKNOWN_FIELD + } + + private fun deserializeValue(transform: ((String) -> T)): T { + if (unwrapped) { + val value = reader.takeNextAs().value ?: "" + return transform(value) + } + // Set and validate mode + reentryFlag = false + if (parsedFieldLocations.isEmpty()) throw DeserializationException("matchedFields is empty, was findNextFieldIndex() called?") + + // Take the first FieldLocation and attempt to parse it into the value specified by the descriptor. + return when (val nextField = parsedFieldLocations.removeFirst()) { + is FieldLocation.Text -> { + val value = when (val peekToken = reader.peek()) { + is XmlToken.Text -> reader.takeNextAs().value ?: "" + is XmlToken.EndElement -> "" + else -> throw DeserializationException("Unexpected token $peekToken") + } + transform(value) + } + is FieldLocation.Attribute -> { + transform( + nextField + .names + .mapNotNull { parentToken.attributes[it] } + .firstOrNull() ?: throw DeserializationException("Expected attrib value ${nextField.names.first()} not found in ${parentToken.name}"), + ) + } + } + } + + override fun skipValue() = reader.skipNext() + + override fun deserializeByte(): Byte = deserializeValue { it.toIntOrNull()?.toByte() ?: throw DeserializationException("Unable to deserialize $it") } + + override fun deserializeInt(): Int = deserializeValue { it.toIntOrNull() ?: throw DeserializationException("Unable to deserialize $it") } + + override fun deserializeShort(): Short = deserializeValue { it.toIntOrNull()?.toShort() ?: throw DeserializationException("Unable to deserialize $it") } + + override fun deserializeLong(): Long = deserializeValue { it.toLongOrNull() ?: throw DeserializationException("Unable to deserialize $it") } + + override fun deserializeFloat(): Float = deserializeValue { it.toFloatOrNull() ?: throw DeserializationException("Unable to deserialize $it") } + + override fun deserializeDouble(): Double = deserializeValue { it.toDoubleOrNull() ?: throw DeserializationException("Unable to deserialize $it") } + + override fun deserializeBigInteger(): BigInteger = deserializeValue { + runCatching { BigInteger(it) } + .getOrElse { throw DeserializationException("Unable to deserialize $it as BigInteger") } + } + + override fun deserializeBigDecimal(): BigDecimal = deserializeValue { + runCatching { BigDecimal(it) } + .getOrElse { throw DeserializationException("Unable to deserialize $it as BigDecimal") } + } + + override fun deserializeString(): String = deserializeValue { it } + + override fun deserializeBoolean(): Boolean = deserializeValue { it.toBoolean() } + + override fun deserializeDocument(): Document { + throw DeserializationException("cannot deserialize unsupported Document type in xml") + } + + override fun deserializeNull(): Nothing? { + reader.takeNextAs() + return null + } + + // A struct deserializer can be called in two "modes": + // 1. to deserialize a value. This calls findNextFieldIndex() followed by deserialize() + // 2. to deserialize a nested container. This calls findNextFieldIndex() followed by a call to another deserialize() + // Because state is built in findNextFieldIndex() that is intended to be used directly in deserialize() (mode 1) + // and there is no explicit way that this type knows which mode is in use, the state built must be cleared. + // this is done by flipping a bit between the two calls. If the bit has not been flipped on any call to findNextFieldIndex() + // it is determined that the nested mode was used and any existing state should be cleared. + // if the state is not cleared, deserialization goes into an infinite loop because the deserializer sees pending fields to pull from the stream + // which are never consumed by the (missing) call to deserialize() + private fun inNestedMode(): Boolean = when (reentryFlag) { + true -> true + false -> { reentryFlag = true; false } + } +} + +// Extract the attributes from the last-read token and match them to [FieldLocation] on the [SdkObjectDescriptor]. +private fun XmlStreamReader.tokenAttributesToFieldLocations(descriptor: SdkObjectDescriptor): MutableList = + if (descriptor.hasXmlAttributes && lastToken is XmlToken.BeginElement) { + val attribFields = descriptor.fields.filter { it.hasTrait() } + val matchedAttribFields = attribFields.filter { it.findFieldLocation(lastToken as XmlToken.BeginElement, peek() ?: throw DeserializationException("Unexpected end of tokens")) != null } + matchedAttribFields.map { FieldLocation.Attribute(it.index, it.toQualifiedNames()) } + .toMutableList() + } else { + mutableListOf() + } + +// Returns a [FieldLocation] if the field maps to the current token +private fun SdkFieldDescriptor.findFieldLocation( + currentToken: XmlToken.BeginElement, + nextToken: XmlToken, +): FieldLocation? = when (val property = toFieldLocation()) { + is FieldLocation.Text -> { + when { + nextToken is XmlToken.Text -> property + nextToken is XmlToken.BeginElement -> property + // The following allows for struct primitives to remain unvisited if no value + // but causes nested deserializers to be called even if they contain no value + nextToken is XmlToken.EndElement && currentToken.name == nextToken.name -> property + else -> null + } + } + is FieldLocation.Attribute -> { + val foundMatch = property.names.any { currentToken.attributes[it]?.isNotBlank() == true } + if (foundMatch) property else null + } +} + +// Produce a [FieldLocation] type based on presence of traits of field +// A field without an attribute trait is assumed to be a text token +private fun SdkFieldDescriptor.toFieldLocation(): FieldLocation = + when (findTrait()) { + null -> FieldLocation.Text(index) // Assume a text value if no attributes defined. + else -> FieldLocation.Attribute(index, toQualifiedNames()) + } + +// Matches fields and tokens with matching qualified name +private fun SdkObjectDescriptor.fieldTokenMatcher(fieldDescriptor: SdkFieldDescriptor, beginElement: XmlToken.BeginElement): Boolean { + if (fieldDescriptor.kind == SerialKind.List && fieldDescriptor.hasTrait()) { + val fieldName = fieldDescriptor.findTrait() ?: XmlCollectionName.Default + val tokenQname = beginElement.name + + // It may be that we are matching a flattened list element or matching a list itself. In the latter + // case the following predicate will not work, so if we fail to match the member + // try again (below) to match against the container. + if (fieldName.element == tokenQname.local) return true + } + + return fieldDescriptor.nameMatches(beginElement.name.tag) +} + +/** + * Return the next token of the specified type or throw [DeserializationException] if incorrect type. + */ +internal inline fun XmlStreamReader.takeNextAs(): TExpected { + val token = this.nextToken() ?: throw DeserializationException("Expected ${TExpected::class} but instead found null") + requireToken(token) + return token as TExpected +} + +/** + * Require that the given token be of type [TExpected] or else throw an exception + */ +internal inline fun requireToken(token: XmlToken) { + if (token::class != TExpected::class) { + throw DeserializationException("Expected ${TExpected::class}; found ${token::class} ($token)") + } +} diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt new file mode 100644 index 000000000..00f124d19 --- /dev/null +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.serde.xml + +import aws.smithy.kotlin.runtime.content.BigDecimal +import aws.smithy.kotlin.runtime.content.BigInteger +import aws.smithy.kotlin.runtime.content.Document +import aws.smithy.kotlin.runtime.serde.* + +/** + * Deserialize primitive values for single values, lists, and maps + */ +internal class XmlPrimitiveDeserializer(private val reader: XmlStreamReader, private val fieldDescriptor: SdkFieldDescriptor) : + PrimitiveDeserializer { + + constructor(input: ByteArray, fieldDescriptor: SdkFieldDescriptor) : this(xmlStreamReader(input), fieldDescriptor) + + private fun deserializeValue(transform: ((String) -> T)): T { + if (reader.peek() is XmlToken.BeginElement) { + // In the case of flattened lists, we "fall" into the first member as there is no wrapper. + // this conditional checks that case for the first element of the list. + val wrapperToken = reader.takeNextAs() + if (wrapperToken.name.local != fieldDescriptor.generalName()) { + // Depending on flat/not-flat, may need to consume multiple start tokens + return deserializeValue(transform) + } + } + + val token = reader.takeNextAs() + + return token.value + ?.let { transform(it) } + ?.also { reader.takeNextAs() } ?: throw DeserializationException("$token specifies nonexistent or invalid value.") + } + + override fun deserializeByte(): Byte = deserializeValue { it.toIntOrNull()?.toByte() ?: throw DeserializationException("Unable to deserialize $it as Byte") } + + override fun deserializeInt(): Int = deserializeValue { it.toIntOrNull() ?: throw DeserializationException("Unable to deserialize $it as Int") } + + override fun deserializeShort(): Short = deserializeValue { it.toIntOrNull()?.toShort() ?: throw DeserializationException("Unable to deserialize $it as Short") } + + override fun deserializeLong(): Long = deserializeValue { it.toLongOrNull() ?: throw DeserializationException("Unable to deserialize $it as Long") } + + override fun deserializeFloat(): Float = deserializeValue { it.toFloatOrNull() ?: throw DeserializationException("Unable to deserialize $it as Float") } + + override fun deserializeDouble(): Double = deserializeValue { it.toDoubleOrNull() ?: throw DeserializationException("Unable to deserialize $it as Double") } + + override fun deserializeBigInteger(): BigInteger = deserializeValue { + runCatching { BigInteger(it) } + .getOrElse { throw DeserializationException("Unable to deserialize $it as BigInteger") } + } + + override fun deserializeBigDecimal(): BigDecimal = deserializeValue { + runCatching { BigDecimal(it) } + .getOrElse { throw DeserializationException("Unable to deserialize $it as BigDecimal") } + } + + override fun deserializeString(): String = deserializeValue { it } + + override fun deserializeBoolean(): Boolean = deserializeValue { it.toBoolean() } + + override fun deserializeDocument(): Document { + throw DeserializationException("cannot deserialize unsupported Document type in xml") + } + + override fun deserializeNull(): Nothing? { + reader.nextToken() ?: throw DeserializationException("Unexpected end of stream") + reader.seek() + reader.nextToken() ?: throw DeserializationException("Unexpected end of stream") + + return null + } +} diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt index 71af12869..a005be4c4 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt @@ -58,11 +58,6 @@ public interface XmlStreamReader { */ public fun skipNext() - /** - * Recursively skip the current token. Meant for discarding unwanted/unrecognized nodes in an XML document - */ - public fun skipCurrent() - /** * Peek at the next token type. Successive calls will return the same value, meaning there is only one * look-ahead at any given time during the parsing of input data. diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt index da486ab53..338e2944f 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt @@ -5,22 +5,28 @@ package aws.smithy.kotlin.runtime.serde.xml import aws.smithy.kotlin.runtime.InternalApi -import aws.smithy.kotlin.runtime.io.Closeable import aws.smithy.kotlin.runtime.serde.DeserializationException /** * An [XmlStreamReader] scoped to reading a single XML element [tag] - * [XmlTagReader] provides a "tag" scoped view into an XML document. Methods return + * XmlTagReader provides a "tag" scoped view into an XML document. Methods return * `null` when the current tag has been exhausted. */ @InternalApi public class XmlTagReader( public val tag: XmlToken.BeginElement, private val reader: XmlStreamReader, -) : Closeable { - private var last: XmlTagReader? = null +) { + // last tag we emitted and returned to the caller + private var lastEmitted: XmlTagReader? = null private var closed = false + /** + * Get the fully qualified tag name of [tag] + */ + public val tagName: String + get() = tag.name.toString() + /** * Return the next actionable token or null if stream is exhausted. */ @@ -45,8 +51,6 @@ public class XmlTagReader( return reader.peek() !is XmlToken.EndElement } - override fun close(): Unit = drop() - /** * Exhaust this [XmlTagReader] to completion. This should always * be invoked to maintain deserialization state. @@ -58,10 +62,11 @@ public class XmlTagReader( } /** - * Return an [XmlTagReader] for the next [XmlToken.BeginElement] + * Return an [XmlTagReader] for the next [XmlToken.BeginElement]. The returned reader + * is only valid until [nextTag] is called or [drop] is invoked on it, whichever comes first. */ public fun nextTag(): XmlTagReader? { - last?.drop() + lastEmitted?.drop() var cand = nextToken() while (cand != null && cand !is XmlToken.BeginElement) { @@ -71,7 +76,7 @@ public class XmlTagReader( val nextTok = cand as? XmlToken.BeginElement return nextTok?.tagReader(reader).also { newScope -> - last = newScope + lastEmitted = newScope } } } @@ -95,12 +100,12 @@ private fun XmlStreamReader.root(): XmlTagReader { @InternalApi public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): XmlTagReader { val start = reader.lastToken as? XmlToken.BeginElement ?: error("expected start tag found ${reader.lastToken}") - check(qualifiedName == start.qualifiedName) { "expected start tag $qualifiedName but current reader state is on ${start.qualifiedName}" } + check(name == start.name) { "expected start tag $name but current reader state is on ${start.name}" } return XmlTagReader(this, reader) } /** - * Unwrap the next token as [XmlToken.Text] and return its' value or throw a [DeserializationException] + * Unwrap the next token as [XmlToken.Text] and return its value or throw a [DeserializationException] */ @InternalApi public fun XmlTagReader.data(): String = diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt index 600a7bf13..8358ab37e 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt @@ -48,6 +48,9 @@ public sealed class XmlToken { return QualifiedName(local, prefix) } } + + val tag: String + get() = toString() } /** @@ -56,7 +59,7 @@ public sealed class XmlToken { @InternalApi public data class BeginElement( override val depth: Int, - public val qualifiedName: QualifiedName, + public val name: QualifiedName, public val attributes: Map = emptyMap(), public val nsDeclarations: List = emptyList(), ) : XmlToken() { @@ -67,33 +70,21 @@ public sealed class XmlToken { // Convenience constructor for name-only nodes with attributes. public constructor(depth: Int, name: String, attributes: Map) : this(depth, QualifiedName(name), attributes) - override fun toString(): String = "<${this.qualifiedName} (${this.depth})>" + override fun toString(): String = "<$name ($depth)>" // convenience function for codegen public fun getAttr(qualified: String): String? = attributes[QualifiedName.from(qualified)] - - /** - * Get the qualified tag name of this element - */ - val name: String - get() = qualifiedName.toString() } /** * The closing of an XML element */ @InternalApi - public data class EndElement(override val depth: Int, public val qualifiedName: QualifiedName) : XmlToken() { + public data class EndElement(override val depth: Int, public val name: QualifiedName) : XmlToken() { // Convenience constructor for name-only nodes. public constructor(depth: Int, name: String) : this(depth, QualifiedName(name)) - override fun toString(): String = " (${this.depth})" - - /** - * Get the qualified tag name of this element - */ - val name: String - get() = qualifiedName.toString() + override fun toString(): String = " ($depth)" } /** @@ -101,7 +92,7 @@ public sealed class XmlToken { */ @InternalApi public data class Text(override val depth: Int, public val value: String?) : XmlToken() { - override fun toString(): String = "${this.value} (${this.depth})" + override fun toString(): String = "$value ($depth)" } @InternalApi @@ -120,9 +111,9 @@ public sealed class XmlToken { } override fun toString(): String = when (this) { - is BeginElement -> "<${this.qualifiedName}>" - is EndElement -> "" - is Text -> "${this.value}" + is BeginElement -> "<$name>" + is EndElement -> "" + is Text -> "$value" StartDocument -> "[StartDocument]" EndDocument -> "[EndDocument]" } @@ -142,7 +133,7 @@ internal fun XmlToken?.terminates(beginToken: XmlToken?): Boolean { if (this !is XmlToken.EndElement) return false if (beginToken !is XmlToken.BeginElement) return false if (depth != beginToken.depth) return false - if (qualifiedName != beginToken.qualifiedName) return false + if (name != beginToken.name) return false return true } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt index 35d28cd71..a241d9578 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt @@ -52,9 +52,6 @@ public class LexingXmlStreamReader(private val source: XmlLexer) : XmlStreamRead else -> scanUntilDepth(startDepth, nextToken()) // Keep scannin'! } } - override fun skipCurrent() { - scanUntilDepth(lastToken?.depth ?: 0, lastToken) - } override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = if (peek(1).terminates(lastToken)) { @@ -112,8 +109,6 @@ private class ChildXmlStreamReader( override fun skipNext() = parent.skipNext() - override fun skipCurrent() = parent.skipCurrent() - override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = parent.subTreeReader(subtreeStartDepth) } @@ -131,10 +126,7 @@ private class EmptyXmlStreamReader(private val parent: XmlStreamReader?) : XmlSt override fun peek(index: Int): XmlToken? = null override fun skipNext() = Unit - override fun skipCurrent() = Unit override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = this } private fun List.getOrNull(index: Int): T? = if (index < size) this[index] else null - -internal fun XmlStreamReader.emptyReader(parent: XmlStreamReader? = this): XmlStreamReader = EmptyXmlStreamReader(parent) diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt index 52496f89f..7959ae783 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt @@ -46,7 +46,7 @@ public class XmlNode { return parseDom(reader) } - internal fun fromToken(token: XmlToken.BeginElement): XmlNode = XmlNode(token.qualifiedName).apply { + internal fun fromToken(token: XmlToken.BeginElement): XmlNode = XmlNode(token.name).apply { attributes.putAll(token.attributes) namespaces.addAll(token.nsDeclarations) } @@ -83,8 +83,8 @@ public fun parseDom(reader: XmlStreamReader): XmlNode { is XmlToken.EndElement -> { val curr = nodeStack.top() - if (curr.name != token.qualifiedName) { - throw DeserializationException("expected end of element: `${curr.name}`, found: `${token.qualifiedName}`") + if (curr.name != token.name) { + throw DeserializationException("expected end of element: `${curr.name}`, found: `${token.name}`") } if (nodeStack.count() > 1) { diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt index b02969b09..26f129c8e 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt @@ -112,7 +112,7 @@ class XmlStreamReaderTest { assertEquals(6, actual.size) assertIs(actual.first()) - assertEquals("payload", (actual.first() as XmlToken.BeginElement).qualifiedName.local) + assertEquals("payload", (actual.first() as XmlToken.BeginElement).name.local) } @Test @@ -193,7 +193,7 @@ class XmlStreamReaderTest { assertEquals(expected, actual) } - private fun skipTest(skipCurrent: Boolean) { + private fun skipTest() { val payload = """ 1> @@ -224,30 +224,19 @@ class XmlStreamReaderTest { nextToken() // end x } - val nt = if (skipCurrent) { - reader.nextToken() - } else { - reader.peek() - } + val nt = reader.peek() assertIs(nt) - assertEquals("unknown", nt.qualifiedName.local) + assertEquals("unknown", nt.name.local) - if (skipCurrent) { - reader.skipCurrent() - } else { - reader.skipNext() - } + reader.skipNext() val y = reader.nextToken() as XmlToken.BeginElement - assertEquals("y", y.qualifiedName.local) + assertEquals("y", y.name.local) } @Test - fun itSkipsNextValuesRecursively() = skipTest(false) - - @Test - fun itSkipsCurrentValuesRecursively() = skipTest(true) + fun itSkipsNextValuesRecursively() = skipTest() @Test fun itSkipsSimpleValues() { @@ -269,11 +258,11 @@ class XmlStreamReaderTest { assertIs(reader.peek()) val zElement = reader.nextToken() as XmlToken.BeginElement - assertEquals("z", zElement.qualifiedName.local) + assertEquals("z", zElement.name.local) reader.skipNext() val yElement = reader.nextToken() as XmlToken.BeginElement - assertEquals("y", yElement.qualifiedName.local) + assertEquals("y", yElement.name.local) } @Test @@ -312,7 +301,7 @@ class XmlStreamReaderTest { assertNull(reader.lastToken, "Expected to start with null lastToken") var peekedToken = reader.peek() assertIs(peekedToken) - assertEquals("l1", peekedToken.qualifiedName.local) + assertEquals("l1", peekedToken.name.local) assertNull(reader.lastToken, "Expected peek to not effect lastToken") reader.nextToken() // consumed l1 assertEquals(1, reader.lastToken?.depth, "Expected level 1") @@ -320,14 +309,14 @@ class XmlStreamReaderTest { peekedToken = reader.nextToken() // consumed l2 assertEquals(2, reader.lastToken?.depth, "Expected level 2") assertIs(peekedToken) - assertEquals("l2", peekedToken.qualifiedName.local) + assertEquals("l2", peekedToken.name.local) reader.peek() assertEquals(2, reader.lastToken?.depth, "Expected peek to not effect level") peekedToken = reader.nextToken() assertEquals(3, reader.lastToken?.depth, "Expected level 3") assertIs(peekedToken) - assertEquals("l3", peekedToken.qualifiedName.local) + assertEquals("l3", peekedToken.name.local) reader.peek() assertEquals(3, reader.lastToken?.depth, "Expected peek to not effect level") } @@ -459,7 +448,7 @@ class XmlStreamReaderTest { val token = unit.nextToken() assertIs(token) - assertEquals("root", token.qualifiedName.local) + assertEquals("root", token.name.local) var subTree1 = unit.subTreeReader() var subTree1Elements = subTree1.allTokens() @@ -575,25 +564,25 @@ class XmlStreamReaderTest { val rTokenTake = actual.nextToken() assertIs(rTokenPeek) - assertEquals("r", rTokenPeek.qualifiedName.local) + assertEquals("r", rTokenPeek.name.local) assertIs(aToken) - assertEquals("a", aToken.qualifiedName.local) + assertEquals("a", aToken.name.local) assertIs(rTokenTake) - assertEquals("r", rTokenTake.qualifiedName.local) + assertEquals("r", rTokenTake.name.local) val bToken = actual.peek(2) assertIs(bToken) - assertEquals("b", bToken.qualifiedName.local) + assertEquals("b", bToken.name.local) val aTokenTake = actual.nextToken() assertIs(aTokenTake) - assertEquals("a", aTokenTake.qualifiedName.local) + assertEquals("a", aTokenTake.name.local) val aCloseToken = actual.peek(5) // 1: 2: 3: 4: 5: assertIs(aCloseToken) - assertEquals("a", aCloseToken.qualifiedName.local) + assertEquals("a", aCloseToken.name.local) val restOfTokens = actual.allTokens() assertEquals(restOfTokens.size, 6) @@ -621,12 +610,12 @@ class XmlStreamReaderTest { // match begin node of depth 2 val l2Node = unit.seek { it.depth == 2 } assertIs(l2Node) - assertEquals("a", l2Node.qualifiedName.local) + assertEquals("a", l2Node.name.local) // verify next token is correct val nextNode = unit.nextToken() assertIs(nextNode) - assertEquals("b", nextNode.qualifiedName.local) + assertEquals("b", nextNode.name.local) // verify no match produces null unit = xmlStreamReader(payload) diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt index 7200ef892..2e0c07b29 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt @@ -118,7 +118,7 @@ class XmlTagReaderTest { val decoder = xmlTagReader(payload) loop@while (true) { val curr = decoder.nextTag() ?: break@loop - when (curr.tag.name) { + when (curr.tagName) { "Child1" -> { assertEquals(1, curr.nextTag()?.data()?.parseInt()?.getOrNull()) assertEquals(2, curr.nextTag()?.data()?.parseInt()?.getOrNull()) From 822b727c82dd154d9b890f15363e0df5e19ea445 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Tue, 27 Feb 2024 12:39:43 -0500 Subject: [PATCH 22/25] fix debug comment --- .../smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index 1e1bcb3e2..b5f2c7e92 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -268,7 +268,7 @@ open class XmlParserGenerator( payloadMembers.forEach { member -> val name = member.getTrait()?.value ?: member.memberName - write("// ${member.memberName} ${escape(member.id.toString())}") + writeMemberDebugComment(ctx, member) writeInline("#S -> builder.#L = ", name, ctx.symbolProvider.toMemberName(member)) deserializeMember(ctx, innerCtx, member, writer) } From feca351024049e9d5c03ef10b169b960dd8ce818 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Tue, 27 Feb 2024 14:47:38 -0500 Subject: [PATCH 23/25] regenerate api dump --- .../api/aws-xml-protocols.api | 4 ++-- runtime/serde/serde-xml/api/serde-xml.api | 23 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api b/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api index ab7db70f5..d92c49cc8 100644 --- a/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api +++ b/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api @@ -1,8 +1,8 @@ public final class aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializerKt { - public static final fun parseEc2QueryErrorResponse ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; + public static final fun parseEc2QueryErrorResponse ([BLkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializerKt { - public static final fun parseRestXmlErrorResponse ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; + public static final fun parseRestXmlErrorResponse ([BLkotlin/coroutines/Continuation;)Ljava/lang/Object; } diff --git a/runtime/serde/serde-xml/api/serde-xml.api b/runtime/serde/serde-xml/api/serde-xml.api index 2b7c80a66..7ed0bcdca 100644 --- a/runtime/serde/serde-xml/api/serde-xml.api +++ b/runtime/serde/serde-xml/api/serde-xml.api @@ -47,6 +47,16 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlCollectionValueNamespa public synthetic fun (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V } +public final class aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer : aws/smithy/kotlin/runtime/serde/Deserializer { + public fun (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;Z)V + public synthetic fun (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun ([BZ)V + public synthetic fun ([BZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun deserializeList (Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$ElementIterator; + public fun deserializeMap (Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$EntryIterator; + public fun deserializeStruct (Laws/smithy/kotlin/runtime/serde/SdkObjectDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$FieldIterator; +} + public final class aws/smithy/kotlin/runtime/serde/xml/XmlError : aws/smithy/kotlin/runtime/serde/FieldTrait { public static final field INSTANCE Laws/smithy/kotlin/runtime/serde/xml/XmlError; public final fun getErrorTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; @@ -143,7 +153,6 @@ public abstract interface class aws/smithy/kotlin/runtime/serde/xml/XmlStreamRea public abstract fun getLastToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public abstract fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public abstract fun peek (I)Laws/smithy/kotlin/runtime/serde/xml/XmlToken; - public abstract fun skipCurrent ()V public abstract fun skipNext ()V public abstract fun subTreeReader (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader$SubtreeStartDepth;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader; } @@ -190,11 +199,11 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterKt { public static synthetic fun xmlStreamWriter$default (ZILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; } -public final class aws/smithy/kotlin/runtime/serde/xml/XmlTagReader : java/io/Closeable { +public final class aws/smithy/kotlin/runtime/serde/xml/XmlTagReader { public fun (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement;Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;)V - public fun close ()V public final fun drop ()V public final fun getTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement; + public final fun getTagName ()Ljava/lang/String; public final fun nextHasValue ()Z public final fun nextTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader; public final fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; @@ -227,9 +236,8 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement : a public final fun getAttr (Ljava/lang/String;)Ljava/lang/String; public final fun getAttributes ()Ljava/util/Map; public fun getDepth ()I - public final fun getName ()Ljava/lang/String; + public final fun getName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; public final fun getNsDeclarations ()Ljava/util/List; - public final fun getQualifiedName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -248,8 +256,7 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$EndElement : aws public static synthetic fun copy$default (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$EndElement;ILaws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$EndElement; public fun equals (Ljava/lang/Object;)Z public fun getDepth ()I - public final fun getName ()Ljava/lang/String; - public final fun getQualifiedName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; + public final fun getName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -279,6 +286,7 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName { public fun equals (Ljava/lang/Object;)Z public final fun getLocal ()Ljava/lang/String; public final fun getPrefix ()Ljava/lang/String; + public final fun getTag ()Ljava/lang/String; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -314,7 +322,6 @@ public final class aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXml public fun getLastToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public fun peek (I)Laws/smithy/kotlin/runtime/serde/xml/XmlToken; - public fun skipCurrent ()V public fun skipNext ()V public fun subTreeReader (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader$SubtreeStartDepth;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader; } From ac7f246600232bf15b08104ac5aeb3d09593f236 Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Tue, 27 Feb 2024 16:36:09 -0500 Subject: [PATCH 24/25] fix map key type --- .../kotlin/codegen/rendering/serde/XmlParserGenerator.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index b5f2c7e92..f99bae49e 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -488,7 +488,7 @@ open class XmlParserGenerator( writer: KotlinWriter, ) { val target = ctx.model.expectShape(member.target) - val keySymbol = ctx.symbolProvider.toSymbol(target.key) + val keySymbol = KotlinTypes.String val valueSymbol = ctx.symbolProvider.toSymbol(target.value) val isSparse = target.hasTrait() writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) @@ -512,7 +512,7 @@ open class XmlParserGenerator( map: MapShape, ): Symbol { val shapeName = StringUtils.capitalize(map.id.getName(ctx.service)) - val keySymbol = ctx.symbolProvider.toSymbol(map.key) + val keySymbol = KotlinTypes.String val valueSymbol = ctx.symbolProvider.toSymbol(map.value) val isSparse = map.hasTrait() val serdeCtx = SerdeCtx("reader") From 1b1d0a52e8dcc4a6d3b2540c874b0e4972224d5a Mon Sep 17 00:00:00 2001 From: Aaron J Todd Date: Tue, 27 Feb 2024 20:49:11 -0500 Subject: [PATCH 25/25] really fix enum key types --- .../rendering/serde/XmlParserGenerator.kt | 10 +++- tests/codegen/serde-tests/model/shared.smithy | 7 +++ .../smithy/kotlin/tests/serde/XmlMapTest.kt | 56 +++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index f99bae49e..4c4e59894 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -463,7 +463,7 @@ open class XmlParserGenerator( writer: KotlinWriter, ) { val target = ctx.model.expectShape(member.target) - val keySymbol = ctx.symbolProvider.toSymbol(target.key) + val keySymbol = KotlinTypes.String val valueSymbol = ctx.symbolProvider.toSymbol(target.value) writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) val isSparse = target.hasTrait() @@ -541,6 +541,14 @@ open class XmlParserGenerator( val keyName = map.key.getTrait()?.value ?: map.key.memberName writeInline("#S -> key = ", keyName) deserializeMember(ctx, innerCtx, map.key, this) + // FIXME - We re-use deserializeMember here but key types targeting enums + // have to pull the raw string value back out because of + // https://github.com/awslabs/smithy-kotlin/issues/1045 + val targetValueShape = ctx.model.expectShape(map.key.target) + if (targetValueShape.type == ShapeType.ENUM) { + writer.indent() + .write(".value") + } val valueName = map.value.getTrait()?.value ?: map.value.memberName if (isSparse) { diff --git a/tests/codegen/serde-tests/model/shared.smithy b/tests/codegen/serde-tests/model/shared.smithy index 5b89a7f2a..042ce8c09 100644 --- a/tests/codegen/serde-tests/model/shared.smithy +++ b/tests/codegen/serde-tests/model/shared.smithy @@ -62,6 +62,11 @@ map FooEnumMap { value: FooEnum, } +map FooEnumKeyMap { + key: FooEnum, + value: Integer +} + @timestampFormat("date-time") timestamp DateTime @@ -134,6 +139,8 @@ structure MapTypesMixin { sparseMap: SparseStringMap, nestedMap: NestedStringMap, listMap: StringListMap, + enumValueMap: FooEnumMap, + enumKeyMap: FooEnumKeyMap, } @mixin diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt index 9909dbd40..530df9186 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt @@ -218,4 +218,60 @@ class XmlMapTest : AbstractXmlTest() { val actualDeserialized = deserializeStructTypeDocument(reader) assertEquals(expected, actualDeserialized) } + + @Test + fun testEnumValueMap() { + val expected = StructType { + enumValueMap = mapOf( + "foo" to FooEnum.Foo, + "bar" to FooEnum.Bar, + ) + } + val payload = """ + + + + foo + Foo + + + bar + Bar + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testEnumKeyMap() { + // see also https://github.com/awslabs/smithy-kotlin/issues/1045 + val expected = StructType { + enumKeyMap = mapOf( + FooEnum.Foo.value to 1, + "Bar" to 2, + "Unknown" to 3, + ) + } + val payload = """ + + + + Foo + 1 + + + Bar + 2 + + + Unknown + 3 + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } }