Skip to content

Commit

Permalink
add plugin setting to control default value serialization and revert …
Browse files Browse the repository at this point in the history
…to only serializing when value differs from runtime
  • Loading branch information
aajtodd committed Oct 3, 2023
1 parent 15ba788 commit 3b455aa
Show file tree
Hide file tree
Showing 8 changed files with 306 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ fun codegenSerializerForShape(
model: Model,
shapeId: String,
location: HttpBinding.Location = HttpBinding.Location.DOCUMENT,
settings: KotlinSettings? = null,
): String {
val ctx = model.newTestContext()
val resolvedSettings = settings ?: model.defaultSettings(TestModelDefault.SERVICE_NAME, TestModelDefault.NAMESPACE)
val ctx = model.newTestContext(settings = resolvedSettings)

val op = ctx.generationCtx.model.expectShape(ShapeId.from(shapeId))
return testRender(ctx.requestMembers(op, location)) { members, writer ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@ package software.amazon.smithy.kotlin.codegen.test

import software.amazon.smithy.build.MockManifest
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.*
import software.amazon.smithy.kotlin.codegen.core.CodegenContext
import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator
import software.amazon.smithy.kotlin.codegen.inferService
import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
import software.amazon.smithy.kotlin.codegen.model.OperationNormalizer
import software.amazon.smithy.kotlin.codegen.model.shapes
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
Expand Down Expand Up @@ -121,7 +120,7 @@ fun Model.newTestContext(
integrations: List<KotlinIntegration> = listOf(),
): TestContext {
val manifest = MockManifest()
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model = this, rootNamespace = packageName, serviceName = serviceName)
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model = this, rootNamespace = packageName, serviceName = serviceName, settings = settings)
val service = this.getShape(ShapeId.from("$packageName#$serviceName")).get().asServiceShape().get()
val delegator = KotlinDelegator(settings, this, manifest, provider)

Expand Down Expand Up @@ -173,6 +172,8 @@ fun Model.defaultSettings(
packageVersion: String = TestModelDefault.MODEL_VERSION,
sdkId: String = TestModelDefault.SDK_ID,
generateDefaultBuildFiles: Boolean = false,
nullabilityCheckMode: CheckMode = CheckMode.CLIENT_CAREFUL,
defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT,
): KotlinSettings {
val serviceId = if (serviceName == null) {
this.inferService()
Expand All @@ -197,6 +198,12 @@ fun Model.defaultSettings(
Node.objectNode()
.withMember("generateDefaultBuildFiles", Node.from(generateDefaultBuildFiles)),
)
.withMember(
"api",
Node.objectNode()
.withMember(ApiSettings.NULLABILITY_CHECK_MODE, Node.from(nullabilityCheckMode.kotlinPluginSetting))
.withMember(ApiSettings.DEFAULT_VALUE_SERIALIZATION_MODE, Node.from(defaultValueSerializationMode.value)),
)
.build(),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,22 +225,59 @@ private fun checkModefromValue(value: String): CheckMode {
return requireNotNull(camelCaseToMode[value]) { "$value is not a valid CheckMode, expected one of ${camelCaseToMode.keys}" }
}

/**
* Get the plugin setting for this check mode
*/
val CheckMode.kotlinPluginSetting: String
get() = toString().toCamelCase()

enum class DefaultValueSerializationMode(val value: String) {
/**
* Always serialize default values even if they are set to the default
*/
ALWAYS("always"),

/**
* Only serialize default values when they differ from the default given in the model.
*/
WHEN_DIFFERENT("whenDifferent"),
;
override fun toString(): String = value
companion object {
fun fromValue(value: String): DefaultValueSerializationMode =
values().find {
it.value == value
} ?: throw IllegalArgumentException("$value is not a valid DefaultValueSerializationMode, expected one of ${values().map { it.value }}")
}
}

/**
* Contains API settings for a Kotlin project
* @param visibility Enum representing the visibility of code-generated classes, objects, interfaces, etc.
* @param nullabilityCheckMode Enum representing the nullability check mode to use
* @param defaultValueSerializationMode Enum representing when default values should be serialized
*/
data class ApiSettings(
val visibility: Visibility = Visibility.PUBLIC,
val nullabilityCheckMode: CheckMode = CheckMode.CLIENT_CAREFUL,
val defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT,
) {
companion object {
const val VISIBILITY = "visibility"
const val NULLABILITY_CHECK_MODE = "nullabilityCheckMode"
const val DEFAULT_VALUE_SERIALIZATION_MODE = "defaultValueSerializationMode"

fun fromNode(node: Optional<ObjectNode>): ApiSettings = node.map {
val visibility = Visibility.fromValue(node.get().getStringMemberOrDefault(VISIBILITY, "public"))
val checkMode = checkModefromValue(node.get().getStringMemberOrDefault(NULLABILITY_CHECK_MODE, "clientCareful"))
ApiSettings(visibility, checkMode)
val defaultValueSerializationMode = DefaultValueSerializationMode.fromValue(
node.get()
.getStringMemberOrDefault(
DEFAULT_VALUE_SERIALIZATION_MODE,
DefaultValueSerializationMode.WHEN_DIFFERENT.value,
),
)
ApiSettings(visibility, checkMode, defaultValueSerializationMode)
}.orElse(Default)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package software.amazon.smithy.kotlin.codegen.rendering.protocol

import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.kotlin.codegen.DefaultValueSerializationMode
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.model.*
Expand Down Expand Up @@ -34,6 +36,7 @@ import software.amazon.smithy.utils.AbstractCodeWriter
class HttpStringValuesMapSerializer(
private val model: Model,
private val symbolProvider: SymbolProvider,
private val settings: KotlinSettings,
private val bindings: List<HttpBindingDescriptor>,
private val resolver: HttpBindingResolver,
private val defaultTimestampFormat: TimestampFormatTrait.Format,
Expand All @@ -43,7 +46,7 @@ class HttpStringValuesMapSerializer(
bindings: List<HttpBindingDescriptor>,
resolver: HttpBindingResolver,
defaultTimestampFormat: TimestampFormatTrait.Format,
) : this(ctx.model, ctx.symbolProvider, bindings, resolver, defaultTimestampFormat)
) : this(ctx.model, ctx.symbolProvider, ctx.settings, bindings, resolver, defaultTimestampFormat)

fun render(
writer: KotlinWriter,
Expand Down Expand Up @@ -77,17 +80,35 @@ class HttpStringValuesMapSerializer(
is StringShape -> renderStringShape(it, memberTarget, writer)
is IntEnumShape -> {
val appendFn = writer.format("append(#S, \"\${input.#L.value}\")", paramName, memberName)
writer.writeWithCondIfCheck(memberSymbol.isNullable, "input.$memberName != null", appendFn)
if (memberSymbol.isNullable) {
writer.write("if (input.$memberName != null) $appendFn")
} else {
val defaultCheck = defaultCheck(member) ?: ""
writer.writeWithCondIfCheck(defaultCheck.isNotEmpty(), defaultCheck, appendFn)
}
}
else -> {
// encode to string
val encodedValue = "\"\${input.$memberName}\""
val appendFn = writer.format("append(#S, #L)", paramName, encodedValue)
writer.writeWithCondIfCheck(memberSymbol.isNullable, "input.$memberName != null", appendFn)
if (memberSymbol.isNullable) {
writer.write("if (input.$memberName != null) $appendFn")
} else {
val defaultCheck = defaultCheck(member) ?: ""
writer.writeWithCondIfCheck(defaultCheck.isNotEmpty(), defaultCheck, appendFn)
}
}
}
}
}
private fun defaultCheck(member: MemberShape): String? {
val memberSymbol = symbolProvider.toSymbol(member)
val memberName = symbolProvider.toMemberName(member)
val defaultValue = memberSymbol.defaultValue()
val checkDefaults = settings.api.defaultValueSerializationMode == DefaultValueSerializationMode.WHEN_DIFFERENT
val check = "input.$memberName != $defaultValue"
return check.takeIf { checkDefaults && !member.isRequired && memberSymbol.isNotNullable && defaultValue != null }
}

private fun AbstractCodeWriter<*>.writeWithCondIfCheck(cond: Boolean, check: String, body: String) {
if (cond) {
Expand Down Expand Up @@ -152,7 +173,7 @@ class HttpStringValuesMapSerializer(
writer.addImport(RuntimeTypes.SmithyClient.IdempotencyTokenProviderExt)
writer.write("append(#S, (input.$memberName ?: context.idempotencyTokenProvider.generateToken()))", paramName)
} else {
val cond =
val nullCheck =
if (location == HttpBinding.Location.QUERY ||
memberTarget.hasTrait<@Suppress("DEPRECATION") software.amazon.smithy.model.traits.EnumTrait>()
) {
Expand All @@ -162,6 +183,8 @@ class HttpStringValuesMapSerializer(
"input.$memberName$nullCheck.isNotEmpty() == true"
}

val cond = defaultCheck(binding.member) ?: nullCheck

val suffix = when {
memberTarget.hasTrait<@Suppress("DEPRECATION") software.amazon.smithy.model.traits.EnumTrait>() -> {
".value"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package software.amazon.smithy.kotlin.codegen.rendering.serde

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.kotlin.codegen.DefaultValueSerializationMode
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.model.targetOrSelf
Expand Down Expand Up @@ -567,14 +568,21 @@ open class SerializeStructGenerator(
val postfix = idempotencyTokenPostfix(memberShape)
val memberSymbol = ctx.symbolProvider.toSymbol(memberShape)
val memberName = ctx.symbolProvider.toMemberName(memberShape)

if (memberSymbol.isNullable) {
val identifier = valueToSerializeName("it")
val fn = serializerFn.format(memberShape, identifier)
writer.write("input.$memberName?.let { $fn }$postfix")
} else {
// always serialize required members, otherwise check if it's a primitive type set to it's default before serializing
val defaultValue = memberSymbol.defaultValue()
val checkDefaults = ctx.settings.api.defaultValueSerializationMode == DefaultValueSerializationMode.WHEN_DIFFERENT
val defaultCheck = if (checkDefaults && !memberShape.isRequired && memberSymbol.isNotNullable && defaultValue != null) {
"if (input.$memberName != $defaultValue) "
} else {
""
}
val fn = serializerFn.format(memberShape, "input.$memberName")
writer.write("$fn$postfix")
writer.write("${defaultCheck}${fn}$postfix")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,21 @@ class KotlinSettingsTest {

assertEquals(expected, apiSettings.nullabilityCheckMode)
}

@ParameterizedTest(name = "{0} ==> {1}")
@CsvSource(
"always, ALWAYS",
"whenDifferent, WHEN_DIFFERENT",
)
fun testDefaultValueSerializationMode(pluginSetting: String, expectedEnumString: String) {
val expected = DefaultValueSerializationMode.valueOf(expectedEnumString)
val contents = """
{
"defaultValueSerializationMode": "$pluginSetting"
}
""".trimIndent()
val apiSettings = ApiSettings.fromNode(Node.parse(contents).asObjectNode())

assertEquals(expected, apiSettings.defaultValueSerializationMode)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

package software.amazon.smithy.kotlin.codegen.rendering.protocol

import software.amazon.smithy.kotlin.codegen.DefaultValueSerializationMode
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.loadModelFromResource
import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.test.assertBalancedBracesAndParens
import software.amazon.smithy.kotlin.codegen.test.newTestContext
import software.amazon.smithy.kotlin.codegen.test.shouldContainOnlyOnceWithDiff
import software.amazon.smithy.kotlin.codegen.test.*
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.shapes.OperationShape
Expand All @@ -20,8 +20,9 @@ import kotlin.test.Test
class HttpStringValuesMapSerializerTest {
private val defaultModel = loadModelFromResource("http-binding-protocol-generator-test.smithy")

private fun getTestContents(model: Model, operationId: String, location: HttpBinding.Location): String {
val testCtx = model.newTestContext()
private fun getTestContents(model: Model, operationId: String, location: HttpBinding.Location, settings: KotlinSettings? = null): String {
val resolvedSettings = settings ?: model.defaultSettings(TestModelDefault.SERVICE_NAME, TestModelDefault.NAMESPACE)
val testCtx = model.newTestContext(settings = resolvedSettings)
val httpGenerator = testCtx.generator as HttpBindingProtocolGenerator
val resolver = httpGenerator.getProtocolHttpBindingResolver(testCtx.generationCtx.model, testCtx.generationCtx.service)
val op = model.expectShape<OperationShape>(operationId)
Expand All @@ -35,8 +36,9 @@ class HttpStringValuesMapSerializerTest {
}

@Test
fun `it handles primitive header shapes`() {
val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.HEADER)
fun `it handles primitive header shapes always mode`() {
val settings = defaultModel.defaultSettings(defaultValueSerializationMode = DefaultValueSerializationMode.ALWAYS)
val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.HEADER, settings)
contents.assertBalancedBracesAndParens()

val expectedContents = """
Expand All @@ -50,10 +52,37 @@ class HttpStringValuesMapSerializerTest {
}

@Test
fun `it handles primitive query shapes`() {
fun `it handles primitive header shapes when different mode`() {
val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.HEADER)
contents.assertBalancedBracesAndParens()

val expectedContents = """
if (input.hBool != false) append("X-d", "${'$'}{input.hBool}")
if (input.hFloat != 0f) append("X-c", "${'$'}{input.hFloat}")
if (input.hInt != 0) append("X-a", "${'$'}{input.hInt}")
if (input.hLong != 0L) append("X-b", "${'$'}{input.hLong}")
append("X-required", "${'$'}{input.hRequiredInt}")
""".trimIndent()
contents.shouldContainOnlyOnceWithDiff(expectedContents)
}

@Test
fun `it handles primitive query shapes when different mode`() {
val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.QUERY)
contents.assertBalancedBracesAndParens()

val expectedContents = """
if (input.qInt != 0) append("q-int", "${'$'}{input.qInt}")
""".trimIndent()
contents.shouldContainOnlyOnceWithDiff(expectedContents)
}

@Test
fun `it handles primitive query shapes always mode`() {
val settings = defaultModel.defaultSettings(defaultValueSerializationMode = DefaultValueSerializationMode.ALWAYS)
val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.QUERY, settings)
contents.assertBalancedBracesAndParens()

val expectedContents = """
append("q-int", "${'$'}{input.qInt}")
""".trimIndent()
Expand All @@ -71,6 +100,46 @@ class HttpStringValuesMapSerializerTest {
contents.shouldContainOnlyOnceWithDiff(expectedContents)
}

@Test
fun `it handles enum default value when different mode`() {
val model = """
@http(method: "POST", uri: "/foo")
operation Foo {
input: FooRequest
}
enum MyEnum {
Variant1,
Variant2
}
intEnum MyIntEnum {
Tay = 1
Lep = 2
}
structure FooRequest {
@default("Variant1")
@httpHeader("X-EnumHeader")
enumHeader: MyEnum
@default(2)
@httpHeader("X-IntEnumHeader")
intEnumHeader: MyIntEnum
}
""".prependNamespaceAndService(operations = listOf("Foo")).toSmithyModel()

val contents = getTestContents(model, "com.test#Foo", HttpBinding.Location.HEADER)
contents.assertBalancedBracesAndParens()

val intEnumValue = "\${input.intEnumHeader.value}"
val expectedContents = """
if (input.enumHeader != com.test.model.MyEnum.fromValue("Variant1")) append("X-EnumHeader", input.enumHeader.value)
if (input.intEnumHeader != com.test.model.MyIntEnum.fromValue(2)) append("X-IntEnumHeader", "$intEnumValue")
""".trimIndent()
contents.shouldContainOnlyOnceWithDiff(expectedContents)
}

@Test
fun `it handles string shapes`() {
val contents = getTestContents(defaultModel, "com.test#SmokeTest", HttpBinding.Location.HEADER)
Expand Down
Loading

0 comments on commit 3b455aa

Please sign in to comment.