Skip to content

Commit

Permalink
wip nullability tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd committed Sep 27, 2023
1 parent 23edf6a commit 059a45c
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import software.amazon.smithy.build.MockManifest
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.expectShape
Expand Down Expand Up @@ -250,9 +251,15 @@ fun generateCode(generator: (KotlinWriter) -> Unit): String {
return rawCodegen.substring(rawCodegen.indexOf(packageDeclaration) + packageDeclaration.length).trim()
}

fun KotlinCodegenPlugin.Companion.createSymbolProvider(model: Model, rootNamespace: String = TestModelDefault.NAMESPACE, sdkId: String = TestModelDefault.SDK_ID, serviceName: String = TestModelDefault.SERVICE_NAME): SymbolProvider {
val settings = model.defaultSettings(serviceName = serviceName, packageName = rootNamespace, sdkId = sdkId)
return createSymbolProvider(model, settings)
fun KotlinCodegenPlugin.Companion.createSymbolProvider(
model: Model,
rootNamespace: String = TestModelDefault.NAMESPACE,
sdkId: String = TestModelDefault.SDK_ID,
serviceName: String = TestModelDefault.SERVICE_NAME,
settings: KotlinSettings? = null,
): SymbolProvider {
val resolvedSettings = settings ?: model.defaultSettings(serviceName = serviceName, packageName = rootNamespace, sdkId = sdkId)
return createSymbolProvider(model, resolvedSettings)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.ClientOptionalTrait
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.SparseTrait
import software.amazon.smithy.model.traits.StreamingTrait
import java.util.logging.Logger

Expand Down Expand Up @@ -93,12 +92,12 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
override fun stringShape(shape: StringShape): Symbol = if (shape.isEnum) {
createEnumSymbol(shape)
} else {
createSymbolBuilder(shape, "String", nullable = true, namespace = "kotlin").build()
createSymbolBuilder(shape, "String", namespace = "kotlin").build()
}

private fun createEnumSymbol(shape: Shape): Symbol {
val namespace = "$rootNamespace.model"
return createSymbolBuilder(shape, shape.defaultName(service), namespace, nullable = true)
return createSymbolBuilder(shape, shape.defaultName(service), namespace)
.definitionFile("${shape.defaultName(service)}.kt")
.build()
}
Expand All @@ -109,7 +108,7 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
override fun structureShape(shape: StructureShape): Symbol {
val name = shape.defaultName(service)
val namespace = "$rootNamespace.model"
val builder = createSymbolBuilder(shape, name, namespace, nullable = true)
val builder = createSymbolBuilder(shape, name, namespace)
.definitionFile("$name.kt")

// add a reference to each member symbol
Expand Down Expand Up @@ -147,10 +146,10 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli

override fun listShape(shape: ListShape): Symbol {
val reference = toSymbol(shape.member)
val valueSuffix = if (shape.hasTrait<SparseTrait>()) "?" else ""
val valueSuffix = if (reference.isNullable) "?" else ""
val valueType = "${reference.name}$valueSuffix"
val fullyQualifiedValueType = "${reference.fullName}$valueSuffix"
return createSymbolBuilder(shape, "List<$valueType>", nullable = true)
return createSymbolBuilder(shape, "List<$valueType>")
.addReferences(reference)
.putProperty(SymbolProperty.FULLY_QUALIFIED_NAME_HINT, "List<$fullyQualifiedValueType>")
.putProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION, "mutableListOf<$valueType>")
Expand All @@ -160,13 +159,13 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli

override fun mapShape(shape: MapShape): Symbol {
val reference = toSymbol(shape.value)
val valueSuffix = if (shape.hasTrait<SparseTrait>()) "?" else ""
val valueSuffix = if (reference.isNullable) "?" else ""
val valueType = "${reference.name}$valueSuffix"
val fullyQualifiedValueType = "${reference.fullName}$valueSuffix"

val keyType = KotlinTypes.String.name
val fullyQualifiedKeyType = KotlinTypes.String.fullName
return createSymbolBuilder(shape, "Map<$keyType, $valueType>", nullable = true)
return createSymbolBuilder(shape, "Map<$keyType, $valueType>")
.addReferences(reference)
.putProperty(SymbolProperty.FULLY_QUALIFIED_NAME_HINT, "Map<$fullyQualifiedKeyType, $fullyQualifiedValueType>")
.putProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION, "mutableMapOf<$keyType, $valueType>")
Expand All @@ -182,7 +181,7 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
val targetSymbol = toSymbol(targetShape)
.toBuilder()
.apply {
if (nullableIndex.isMemberNullable(shape, NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1_NO_INPUT)) nullable()
if (nullableIndex.isMemberNullable(shape, settings.api.nullabilityCheckMode)) nullable()

if (!shape.hasTrait<ClientOptionalTrait>()) { // @ClientOptional supersedes @default
shape.getTrait<DefaultTrait>()?.let {
Expand Down Expand Up @@ -238,16 +237,16 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli

override fun timestampShape(shape: TimestampShape?): Symbol {
val dependency = KotlinDependency.CORE
return createSymbolBuilder(shape, "Instant", nullable = true)
return createSymbolBuilder(shape, "Instant")
.namespace("${dependency.namespace}.time", ".")
.addDependency(dependency)
.build()
}

override fun blobShape(shape: BlobShape): Symbol = if (shape.hasTrait<StreamingTrait>()) {
RuntimeTypes.Core.Content.ByteStream.asNullable()
RuntimeTypes.Core.Content.ByteStream
} else {
createSymbolBuilder(shape, "ByteArray", nullable = true, namespace = "kotlin").build()
createSymbolBuilder(shape, "ByteArray", namespace = "kotlin").build()
}

override fun documentShape(shape: DocumentShape?): Symbol =
Expand All @@ -256,7 +255,7 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
override fun unionShape(shape: UnionShape): Symbol {
val name = shape.defaultName(service)
val namespace = "$rootNamespace.model"
val builder = createSymbolBuilder(shape, name, namespace, nullable = true)
val builder = createSymbolBuilder(shape, name, namespace)
.definitionFile("$name.kt")

// add a reference to each member symbol
Expand All @@ -265,7 +264,10 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
return builder.build()
}

override fun resourceShape(shape: ResourceShape?): Symbol = createSymbolBuilder(shape, "Resource").build()
override fun resourceShape(shape: ResourceShape?): Symbol {
// The Kotlin SDK does not produce code explicitly based on Resources
error { "unexpected codegen code path" }
}

override fun operationShape(shape: OperationShape?): Symbol {
// The Kotlin SDK does not produce code explicitly based on Operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class StructureGenerator(
memberName,
memberSymbol,
)
// FIXME - shouldn't this only target string shapes?
if (memberShape.isNonBlankInStruct) {
writer
.indent()
Expand Down Expand Up @@ -242,7 +243,6 @@ class StructureGenerator(
.withBlock("public class Builder {", "}") {
for (member in sortedMembers) {
val (memberName, memberSymbol) = memberNameSymbolIndex[member]!!
// we want the type names sans nullability (?) for arguments
writer.renderMemberDocumentation(model, member)
writer.renderAnnotations(member)
write("public var #L: #E", memberName, memberSymbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
import org.junit.jupiter.params.provider.ValueSource
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.kotlin.codegen.ApiSettings
import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin
import software.amazon.smithy.kotlin.codegen.core.KotlinDependency.Companion.CORE
import software.amazon.smithy.kotlin.codegen.model.defaultValue
import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.model.fullNameHint
import software.amazon.smithy.kotlin.codegen.model.isNullable
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.model.traits.SYNTHETIC_NAMESPACE
import software.amazon.smithy.kotlin.codegen.test.*
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.ClientOptionalTrait
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class SymbolProviderTest {
Expand Down Expand Up @@ -775,4 +776,150 @@ class SymbolProviderTest {

assertEquals("com.test.model.Events", symbol.references[0].symbol.fullName)
}

@ParameterizedTest(name = "{index} ==> ''{0}''")
@ValueSource(strings = ["CLIENT_CAREFUL", "CLIENT"])
fun `it handles client nullability for IDL v2 check modes`(rawCheckMode: String) {
val model = """
operation TestOp {
input: OpInput
}
@input
structure OpInput {
@httpHeader("x-test")
xTestHeader: String,
@required
boolean: Boolean
list: IntList,
map: StringMap,
@required
top: MyStruct
}
integer MyInt
@default("foo")
string MyString
list IntList {
member: MyInt
}
@sparse
list SparseIntList {
member: MyInt
}
map StringMap {
key: String,
value: MyInt
}
structure MyStruct {
@required
union: MyUnion,
@required
string: String,
@required
list: IntList,
sparseList: SparseIntList,
@required
nested: Nested,
@required
@clientOptional
clientOptionalString: String,
@required
enum: MyEnum,
@default(1)
defaultInt: MyInt,
@default("foo")
defaultString: MyString,
@default(null)
defaultButNullString: MyString
}
enum MyEnum {
Variant1,
Variant2
}
structure Nested {
nestedString: String
}
union MyUnion {
blob: Blob,
boolean: Boolean,
date: Timestamp,
int: Integer,
}
""".prependNamespaceAndService(operations = listOf("TestOp")).toSmithyModel()

val checkMode = CheckMode.valueOf(rawCheckMode)
val settings = model.defaultSettings().copy(api = ApiSettings(nullabilityCheckMode = checkMode))
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model, settings = settings)

// opInput members always optional because of @input
val opInputStruct = model.expectShape<StructureShape>("com.test#OpInput")
opInputStruct.members().forEach {
assertTrue(provider.toSymbol(it).isNullable, "expected $it to be nullable because its marked with @input trait")
}

// struct/union members optional in client careful
val myStruct = model.expectShape<StructureShape>("com.test#MyStruct")
val unionAndStructMembers = listOf("union", "nested").map { myStruct.getMember(it).get() }
unionAndStructMembers.forEach {
val memberSymbol = provider.toSymbol(it)
when (checkMode) {
CheckMode.CLIENT_CAREFUL -> assertTrue(memberSymbol.isNullable, "struct/union $it should be optional in $checkMode mode")
else -> assertFalse(memberSymbol.isNullable, "struct/union $it should be required in $checkMode")
}
}

// required members not optional - except client careful
myStruct.members()
.filter(MemberShape::isRequired)
.filterNot { it in unionAndStructMembers }
.forEach {
val memberSymbol = provider.toSymbol(it)
if (it.hasTrait<ClientOptionalTrait>()) {
assertTrue(memberSymbol.isNullable, "@clientOptional member $it should be nullable regardless of @required")
} else {
assertFalse(memberSymbol.isNullable, "@required member $it should not be nullable")
}
}

// union members are not optional
val unionShape = model.expectShape<UnionShape>("com.test#MyUnion")
assertTrue(unionShape.members().map { provider.toSymbol(it) }.none { it.isNullable }, "union members should not be nullable")

// default null are optional
myStruct.members()
.filter { it.hasNullDefault() }
.forEach {
val memberSymbol = provider.toSymbol(it)
assertTrue(memberSymbol.isNullable, "member $it with explicit null default should be nullable")
}

// non-null default are not optional
myStruct.members()
.filter { it.hasNonNullDefault() }
.forEach {
val memberSymbol = provider.toSymbol(it)
assertFalse(memberSymbol.isNullable, "member $it with non-null default should not be nullable")
}
}
}

0 comments on commit 059a45c

Please sign in to comment.