Skip to content

Commit

Permalink
Added documentation and modified code-gen logic
Browse files Browse the repository at this point in the history
  • Loading branch information
0marperez committed Sep 27, 2023
1 parent 422c33e commit 8861cc6
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import software.amazon.smithy.kotlin.codegen.rendering.serde.JsonSerdeDescriptor
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape

/**
* Overrides the [JsonParserGenerator] when using `AWS Json 1.0`, `AWS Json 1.1`, and `RestJson 1.0` protocols.
*
* See https://github.com/smithy-lang/smithy/pull/1945
*/
class AwsJsonProtocolParserGenerator(
private val protocolGenerator: ProtocolGenerator,
private val supportsJsonNameTrait: Boolean = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@ import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape

/**
* Overrides the [JsonSerdeDescriptorGenerator] when using `AWS Json 1.0`, `AWS Json 1.1`, and `RestJson 1.0` protocols.
*
* See: https://github.com/smithy-lang/smithy/pull/1945
*/
class AwsJsonProtocolSerdeDescriptorGenerator(
ctx: RenderingContext<Shape>,
memberShapes: List<MemberShape>? = null,
supportsJsonNameTrait: Boolean = true,
) : JsonSerdeDescriptorGenerator(ctx, memberShapes, supportsJsonNameTrait) {

/**
* Adds a trait to ignore `__type` in union shapes for AWS specific JSON protocols
* Adds a trait to ignore `__type` in union shapes for `AWS Json 1.0`, `AWS Json 1.1`, `RestJson 1.0` protocols
* Sometimes the unnecessary trait `__type` is added and needs to be ignored
*
* NOTE: Will be ignored unless it's in the model
Expand All @@ -30,7 +35,12 @@ class AwsJsonProtocolSerdeDescriptorGenerator(
*/
override fun getObjectDescriptorTraits(): List<SdkFieldDescriptorTrait> {
val traitList = super.getObjectDescriptorTraits().toMutableList()
if (ctx.shape?.isUnionShape == true) traitList.add(RuntimeTypes.Serde.IgnoreKey, "__type".dq(), "false")
val typeMember = memberShapes.find { it.memberName == "__type" }

if (ctx.shape?.isUnionShape == true && typeMember == null) {
traitList.add(RuntimeTypes.Serde.SerdeJson.IgnoreKey, "__type".dq())
}

return traitList
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import kotlin.test.Test

class AwsJsonProtocolSerdeDescriptorGeneratorTest {
@Test
fun itHandlesUnionsAndAddsIgnoreKeysTrait() {
fun itAddsIgnoreKeysTrait() {
val model = """
@http(method: "POST", uri: "/foo")
operation Foo {
Expand Down Expand Up @@ -40,7 +40,7 @@ class AwsJsonProtocolSerdeDescriptorGeneratorTest {
val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, JsonSerialName("x"))
val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, JsonSerialName("y"))
val OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
trait(IgnoreKey("__type", false))
trait(IgnoreKey("__type"))
field(X_DESCRIPTOR)
field(Y_DESCRIPTOR)
}
Expand All @@ -49,4 +49,43 @@ class AwsJsonProtocolSerdeDescriptorGeneratorTest {
val contents = writer.toString()
contents.shouldContainOnlyOnceWithDiff(expectedDescriptors)
}

@Test
fun itDoesNotAddIgnoreKeysTrait() {
val model = """
@http(method: "POST", uri: "/foo")
operation Foo {
input: FooRequest
}
structure FooRequest {
strVal: String,
intVal: Integer
}
union Bar {
__type: String,
y: String,
}
""".prependNamespaceAndService(operations = listOf("Foo")).toSmithyModel()

val testCtx = model.newTestContext()
val writer = testCtx.newWriter()
val shape = model.expectShape(ShapeId.from("com.test#Bar"))
val renderingCtx = testCtx.toRenderingContext(writer, shape)

AwsJsonProtocolSerdeDescriptorGenerator(renderingCtx).render()

val expectedDescriptors = """
val TYPE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, JsonSerialName("__type"))
val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, JsonSerialName("y"))
val OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
field(TYPE_DESCRIPTOR)
field(Y_DESCRIPTOR)
}
""".formatForTest("")

val contents = writer.toString()
contents.shouldContainOnlyOnceWithDiff(expectedDescriptors)
}
}

0 comments on commit 8861cc6

Please sign in to comment.