Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ignore __type when deserializing union for AWS Json protocols #964

Merged
merged 6 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ object RuntimeTypes {
val JsonSerialName = symbol("JsonSerialName")
val JsonSerializer = symbol("JsonSerializer")
val JsonDeserializer = symbol("JsonDeserializer")
val IgnoreKey = symbol("IgnoreKey")
}

object SerdeXml : RuntimeTypePackage(KotlinDependency.SERDE_XML) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ open class JsonParserGenerator(

open val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

open fun descriptorGenerator(
ctx: ProtocolGenerator.GenerationContext,
shape: Shape,
members: List<MemberShape>,
writer: KotlinWriter,
): JsonSerdeDescriptorGenerator = JsonSerdeDescriptorGenerator(
ctx.toRenderingContext(protocolGenerator, shape, writer),
members,
supportsJsonNameTrait,
)

override fun operationDeserializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List<MemberShape>): Symbol {
val outputSymbol = op.output.get().let { ctx.symbolProvider.toSymbol(ctx.model.expectShape(it)) }
return op.bodyDeserializer(ctx.settings) { writer ->
Expand Down Expand Up @@ -127,7 +138,7 @@ open class JsonParserGenerator(
members: List<MemberShape>,
writer: KotlinWriter,
) {
JsonSerdeDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members, supportsJsonNameTrait).render()
descriptorGenerator(ctx, shape, members, writer).render()
if (shape.isUnionShape) {
val name = ctx.symbolProvider.toSymbol(shape).name
DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,13 @@ private class JsonFieldIterator(
val token = reader.nextTokenOf<JsonToken.Name>()
val propertyName = token.value
val field = descriptor.fields.find { it.serialName == propertyName }
field?.index ?: Deserializer.FieldIterator.UNKNOWN_FIELD

if (IgnoreKey(propertyName) in descriptor.traits) {
reader.skipNext() // the value of the ignored key
return findNextFieldIndex()
} else {
field?.index ?: Deserializer.FieldIterator.UNKNOWN_FIELD
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ public data class JsonSerialName(public val name: String) : FieldTrait
@InternalApi
public val SdkFieldDescriptor.serialName: String
get() = expectTrait<JsonSerialName>().name

/**
* Indicates to deserializers to ignore field/key
*/
@InternalApi
public data class IgnoreKey(public val key: String) : FieldTrait
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package aws.smithy.kotlin.runtime.serde.json

import aws.smithy.kotlin.runtime.serde.SdkFieldDescriptor
import aws.smithy.kotlin.runtime.serde.SdkObjectDescriptor
import aws.smithy.kotlin.runtime.serde.SerialKind
import aws.smithy.kotlin.runtime.serde.deserializeStruct
import kotlin.test.Test
import kotlin.test.assertEquals

class JsonDeserializerIgnoresKeysTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Every test has explicitly modeled field descriptors. You're missing coverage for unmodeled fields that we want to skip rather than enumerate as unknown.

private val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("x"))
private val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("y"))
private val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("z"))
private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
trait(IgnoreKey("z"))
field(X_DESCRIPTOR)
field(Y_DESCRIPTOR)
field(Z_DESCRIPTOR)
}

@Test
fun itIgnoresKeys() {
val payload = """
{
"x": 1,
"y": 2,
"z": 3
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
X_DESCRIPTOR.index -> x = deserializeInt()
Y_DESCRIPTOR.index -> y = deserializeInt()
Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}

@Test
fun itIgnoresKeysOutOfOrder() {
val payload = """
{
"z": 3,
"x": 1,
"y": 2
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
X_DESCRIPTOR.index -> x = deserializeInt()
Y_DESCRIPTOR.index -> y = deserializeInt()
Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}

@Test
fun itIgnoresKeysManyTimes() {
val payload = """
{
"x": 1,
"y": 2,
"z": 3,
"z": 3,
"z": 3
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
X_DESCRIPTOR.index -> x = deserializeInt()
Y_DESCRIPTOR.index -> y = deserializeInt()
Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}

private val W_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("w"))
private val MULT_KEYS_OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
trait(IgnoreKey("w"))
trait(IgnoreKey("z"))
field(W_DESCRIPTOR)
field(X_DESCRIPTOR)
field(Y_DESCRIPTOR)
field(Z_DESCRIPTOR)
}

@Test
fun itIgnoresMultipleKeys() {
val payload = """
{
"w": 0,
"x": 1,
"y": 2,
"z": 3
}
""".trimIndent().encodeToByteArray()

val deserializer = JsonDeserializer(payload)
var w: Int? = null
var x: Int? = null
var y: Int? = null
var z: Int? = null
deserializer.deserializeStruct(MULT_KEYS_OBJ_DESCRIPTOR) {
loop@ while (true) {
when (findNextFieldIndex()) {
W_DESCRIPTOR.index -> w = deserializeInt()
X_DESCRIPTOR.index -> x = deserializeInt()
Y_DESCRIPTOR.index -> y = deserializeInt()
Z_DESCRIPTOR.index -> z = deserializeInt()
null -> break@loop
}
}
}

assertEquals(null, w)
assertEquals(1, x)
assertEquals(2, y)
assertEquals(null, z)
}
}
Loading