Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ignore __type when deserializing union for AWS Json protocols #964

Merged
merged 6 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ object RuntimeTypes {
val JsonSerialName = symbol("JsonSerialName")
val JsonSerializer = symbol("JsonSerializer")
val JsonDeserializer = symbol("JsonDeserializer")
val IgnoreKey = symbol("IgnoreKey")
}

object SerdeXml : RuntimeTypePackage(KotlinDependency.SERDE_XML) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ open class JsonParserGenerator(

open val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

open fun descriptorGenerator(
ctx: ProtocolGenerator.GenerationContext,
shape: Shape,
members: List<MemberShape>,
writer: KotlinWriter,
): JsonSerdeDescriptorGenerator = JsonSerdeDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members, supportsJsonNameTrait)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style: Line length > 120 chars


override fun operationDeserializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List<MemberShape>): Symbol {
val outputSymbol = op.output.get().let { ctx.symbolProvider.toSymbol(ctx.model.expectShape(it)) }
return op.bodyDeserializer(ctx.settings) { writer ->
Expand Down Expand Up @@ -127,7 +134,7 @@ open class JsonParserGenerator(
members: List<MemberShape>,
writer: KotlinWriter,
) {
JsonSerdeDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members, supportsJsonNameTrait).render()
descriptorGenerator(ctx, shape, members, writer).render()
if (shape.isUnionShape) {
val name = ctx.symbolProvider.toSymbol(shape).name
DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,13 @@ private class JsonFieldIterator(
val token = reader.nextTokenOf<JsonToken.Name>()
val propertyName = token.value
val field = descriptor.fields.find { it.serialName == propertyName }
field?.index ?: Deserializer.FieldIterator.UNKNOWN_FIELD

if (descriptor.traits.contains(IgnoreKey(propertyName))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style: Consider if (IgnoreKey(propertyName) in descriptor.traits)

reader.skipNext() // the value of the ignored key
return findNextFieldIndex()
} else {
field?.index ?: Deserializer.FieldIterator.UNKNOWN_FIELD
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ public data class JsonSerialName(public val name: String) : FieldTrait
@InternalApi
public val SdkFieldDescriptor.serialName: String
get() = expectTrait<JsonSerialName>().name

/**
* Indicates to deserializers to ignore field/key
*/
@InternalApi
public data class IgnoreKey(public val key: String) : FieldTrait
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package aws.smithy.kotlin.runtime.serde.json

import aws.smithy.kotlin.runtime.serde.SdkFieldDescriptor
import aws.smithy.kotlin.runtime.serde.SdkObjectDescriptor
import aws.smithy.kotlin.runtime.serde.SerialKind
import aws.smithy.kotlin.runtime.serde.deserializeStruct
import kotlin.test.Test
import kotlin.test.assertEquals

class JsonDeserializerIgnoresKeysTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Every test has explicitly modeled field descriptors. You're missing coverage for unmodeled fields that we want to skip rather than enumerate as unknown.

class IgnoresKeysTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this double-nested class seems unnecessary, I think you can just declare these at the top-level without using a companion object for simplicity

companion object {
val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("x"))
val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("y"))
val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("z"))
val OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
trait(IgnoreKey("z")) // <----
field(X_DESCRIPTOR)
field(Y_DESCRIPTOR)
field(Z_DESCRIPTOR)
}
}
}

@Test
fun itIgnoresKeys() {
val payload = """
{
"x": 1,
"y": 2,
"z": 3
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(IgnoresKeysTest.OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
IgnoresKeysTest.X_DESCRIPTOR.index -> x = deserializeInt()
IgnoresKeysTest.Y_DESCRIPTOR.index -> y = deserializeInt()
IgnoresKeysTest.Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}

@Test
fun itIgnoresKeysOutOfOrder() {
val payload = """
{
"z": 3,
"x": 1,
"y": 2
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(IgnoresKeysTest.OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
IgnoresKeysTest.X_DESCRIPTOR.index -> x = deserializeInt()
IgnoresKeysTest.Y_DESCRIPTOR.index -> y = deserializeInt()
IgnoresKeysTest.Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}

@Test
fun itIgnoresKeysManyTimes() {
val payload = """
{
"x": 1,
"y": 2,
"z": 3,
"z": 3,
"z": 3
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(IgnoresKeysTest.OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
IgnoresKeysTest.X_DESCRIPTOR.index -> x = deserializeInt()
IgnoresKeysTest.Y_DESCRIPTOR.index -> y = deserializeInt()
IgnoresKeysTest.Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}

class IgnoresMultipleKeysTest {
companion object {
val W_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("w"))
val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("x"))
val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("y"))
val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("z"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicated, can re-use the others descriptors above

val OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
trait(IgnoreKey("w")) // <----
trait(IgnoreKey("z")) // <----
field(W_DESCRIPTOR)
field(X_DESCRIPTOR)
field(Y_DESCRIPTOR)
field(Z_DESCRIPTOR)
}
}
}

@Test
fun itIgnoresMultipleKeys() {
val payload = """
{
"w": 0,
"x": 1,
"y": 2,
"z": 3
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var w: Int? = null
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(IgnoresMultipleKeysTest.OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
IgnoresMultipleKeysTest.W_DESCRIPTOR.index -> w = deserializeInt()
IgnoresMultipleKeysTest.X_DESCRIPTOR.index -> x = deserializeInt()
IgnoresMultipleKeysTest.Y_DESCRIPTOR.index -> y = deserializeInt()
IgnoresMultipleKeysTest.Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(null, w)
assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}
}
Loading