Skip to content

Commit

Permalink
fix: correctly deserialize/render input values for smoke tests (#1181)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianbotsf authored Nov 20, 2024
1 parent 1d3028e commit 9e1a8d6
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class ShapeValueGenerator(
private fun classDeclaration(writer: KotlinWriter, shape: StructureShape, block: () -> Unit) {
val symbol = symbolProvider.toSymbol(shape)
// invoke the generated DSL builder for the class
writer.writeInline("#L {", symbol.name)
writer.writeInline("#T {", symbol)
.ensureNewline()
.indent()
.call { block() }
Expand Down Expand Up @@ -129,7 +129,7 @@ class ShapeValueGenerator(
val suffix = when {
shape.isEnum -> {
val symbol = symbolProvider.toSymbol(shape)
writer.writeInline("#L.fromValue(", symbol.name)
writer.writeInline("#T.fromValue(", symbol)
")"
}

Expand Down Expand Up @@ -222,7 +222,7 @@ class ShapeValueGenerator(
val currSymbol = generator.symbolProvider.toSymbol(currShape)
val memberName = generator.symbolProvider.toMemberName(member)
val variantName = memberName.replaceFirstChar { c -> c.uppercaseChar() }
writer.writeInline("${currSymbol.name}.$variantName(")
writer.writeInline("#T.#L(", currSymbol, variantName)
generator.instantiateShapeInline(writer, memberShape, valueNode)
writer.writeInline(")")
}
Expand All @@ -243,14 +243,14 @@ class ShapeValueGenerator(
ShapeType.DOUBLE,
ShapeType.FLOAT,
-> {
val symbolName = generator.symbolProvider.toSymbol(currShape).name
val symbol = generator.symbolProvider.toSymbol(currShape)
val symbolMember = when (node.value) {
"Infinity" -> "POSITIVE_INFINITY"
"-Infinity" -> "NEGATIVE_INFINITY"
"NaN" -> "NaN"
else -> throw CodegenException("""Cannot interpret $symbolName value "${node.value}".""")
else -> throw CodegenException("""Cannot interpret $symbol value "${node.value}".""")
}
writer.writeInline("#L", "$symbolName.$symbolMember")
writer.writeInline("#T.#L", symbol, symbolMember)
}

ShapeType.BIG_INTEGER -> writer.writeInline("#T(#S)", RuntimeTypes.Core.Content.BigInteger, node.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@ import software.amazon.smithy.kotlin.codegen.integration.SectionId
import software.amazon.smithy.kotlin.codegen.integration.SectionKey
import software.amazon.smithy.kotlin.codegen.model.getTrait
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.isStringEnumShape
import software.amazon.smithy.kotlin.codegen.rendering.ShapeValueGenerator
import software.amazon.smithy.kotlin.codegen.rendering.endpoints.EndpointParametersGenerator
import software.amazon.smithy.kotlin.codegen.rendering.endpoints.EndpointProviderGenerator
import software.amazon.smithy.kotlin.codegen.rendering.protocol.stringToNumber
import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.EndpointParams
import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.EndpointProvider
import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.Name
import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.Value
import software.amazon.smithy.kotlin.codegen.rendering.util.format
import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.kotlin.codegen.utils.toCamelCase
import software.amazon.smithy.kotlin.codegen.utils.toPascalCase
import software.amazon.smithy.kotlin.codegen.utils.topDownOperations
import software.amazon.smithy.model.node.*
import software.amazon.smithy.model.shapes.*
Expand Down Expand Up @@ -134,18 +132,19 @@ class SmokeTestsRunnerGenerator(
}
write("")
withInlineBlock("try {", "} ") {
renderClient(testCase)
renderOperation(operation, testCase)
renderTestCase(operation, testCase)
}
withBlock("catch (exception: Exception) {", "}") {
renderCatchBlock(testCase)
}
}
}

private fun renderClient(testCase: SmokeTestCase) {
writer.withInlineBlock("#L {", "}", service) {
private fun renderTestCase(operation: OperationShape, testCase: SmokeTestCase) {
writer.withBlock("#T {", "}", service) {
renderClientConfig(testCase)
closeAndOpenBlock("}.#T { client ->", RuntimeTypes.Core.IO.use)
renderOperation(operation, testCase)
}
}

Expand All @@ -161,10 +160,8 @@ class SmokeTestsRunnerGenerator(
return
}

testCase.clientConfig!!.forEach { config ->
val name = config.key.value.toCamelCase()
val value = config.value.format()

testCase.clientConfig.forEach { (name, unformattedValue) ->
val value = unformattedValue.format()
writer.declareSection(
SmokeTestSectionIds.ClientConfig,
mapOf(
Expand All @@ -174,33 +171,23 @@ class SmokeTestsRunnerGenerator(
EndpointParams to EndpointParametersGenerator.getSymbol(settings),
),
) {
writer.writeInline("#L = #L", name, value)
writer.write("#L = #L", name, value)
}
}
}

private fun renderOperation(operation: OperationShape, testCase: SmokeTestCase) {
val operationSymbol = symbolProvider.toSymbol(model.getShape(operation.input.get()).get())

writer.withBlock(".#T { client ->", "}", RuntimeTypes.Core.IO.use) {
withBlock("client.#L(", ")", operation.defaultName()) {
withBlock("#L {", "}", operationSymbol) {
renderOperationParameters(operation, testCase)
}
}
}
}

private fun renderOperationParameters(operation: OperationShape, testCase: SmokeTestCase) {
if (!testCase.hasOperationParameters) return
val inputParams = testCase.params.getOrNull()

val paramsToShapes = mapOperationParametersToModeledShapes(operation)
writer.writeInline("client.#L", operation.defaultName())

testCase.operationParameters.forEach { param ->
val paramName = param.key.value.toCamelCase()
writer.writeInline("#L = ", paramName)
val paramShape = paramsToShapes[paramName] ?: throw IllegalArgumentException("Unable to find shape for operation parameter '$paramName' in smoke test '${testCase.functionName}'.")
renderOperationParameter(paramName, param.value, paramShape, testCase)
if (inputParams == null) {
writer.write("()")
} else {
writer.withBlock("(", ")") {
val inputShape = model.expectShape(operation.input.get())
ShapeValueGenerator(model, symbolProvider).instantiateShapeInline(writer, inputShape, inputParams)
}
}
}

Expand Down Expand Up @@ -228,56 +215,6 @@ class SmokeTestsRunnerGenerator(
}

// Helpers
/**
* Renders a [SmokeTestCase] operation parameter
*/
private fun renderOperationParameter(
paramName: String,
node: Node,
shape: Shape,
testCase: SmokeTestCase,
) {
when {
// String enum
node is StringNode && shape.isStringEnumShape -> {
val enumSymbol = symbolProvider.toSymbol(shape)
val enumValue = node.value.toPascalCase()
writer.write("#T.#L", enumSymbol, enumValue)
}
// Int enum
node is NumberNode && shape is IntEnumShape -> {
val enumSymbol = symbolProvider.toSymbol(shape)
val enumValue = node.format()
writer.write("#T.fromValue(#L.toInt())", enumSymbol, enumValue)
}
// Number
node is NumberNode && shape is NumberShape -> writer.write("#L.#L", node.format(), stringToNumber(shape))
// Object
node is ObjectNode -> {
val shapeSymbol = symbolProvider.toSymbol(shape)
writer.withBlock("#T {", "}", shapeSymbol) {
node.members.forEach { member ->
val memberName = member.key.value.toCamelCase()
val memberShape = shape.allMembers[member.key.value] ?: throw IllegalArgumentException("Unable to find shape for operation parameter '$paramName' in smoke test '${testCase.functionName}'.")
writer.writeInline("#L = ", memberName)
renderOperationParameter(memberName, member.value, memberShape, testCase)
}
}
}
// List
node is ArrayNode && shape is CollectionShape -> {
writer.withBlock("listOf(", ")") {
node.elements.forEach { element ->
renderOperationParameter(paramName, element, model.expectShape(shape.member.target), testCase)
writer.write(",")
}
}
}
// Everything else
else -> writer.write("#L", node.format())
}
}

/**
* Tries to get the specific exception required in the failure criterion of a test.
* If no specific exception is required we default to the generic smoke tests failure exception.
Expand All @@ -304,14 +241,6 @@ class SmokeTestsRunnerGenerator(
writer.write("println(#S)", testResult)
}

/**
* Maps an operations parameters to their shapes
*/
private fun mapOperationParametersToModeledShapes(operation: OperationShape): Map<String, Shape> =
model.getShape(operation.inputShape).get().allMembers.map { (key, value) ->
key.toCamelCase() to model.getShape(value.target).get()
}.toMap()

/**
* Derives a function name for a [SmokeTestCase]
*/
Expand Down Expand Up @@ -345,8 +274,12 @@ class SmokeTestsRunnerGenerator(
/**
* Get the client configuration required for a [SmokeTestCase]
*/
private val SmokeTestCase.clientConfig: MutableMap<StringNode, Node>?
get() = this.vendorParams.get().members
private val SmokeTestCase.clientConfig: Map<String, Node>
get() = vendorParams
.getOrNull()
?.members
?.mapKeys { (key, _) -> key.value }
.orEmpty()

// Constants
private val model = ctx.model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ software.amazon.smithy.kotlin.codegen.rendering.endpoints.discovery.EndpointDisc
software.amazon.smithy.kotlin.codegen.rendering.endpoints.SdkEndpointBuiltinIntegration
software.amazon.smithy.kotlin.codegen.rendering.compression.RequestCompressionIntegration
software.amazon.smithy.kotlin.codegen.rendering.auth.SigV4AsymmetricAuthSchemeIntegration
# software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestsIntegration
software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestsIntegration
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class SmokeTestsRunnerGeneratorTest {
}
vendorParamsShape: AwsVendorParams,
vendorParams: {
region: "eu-central-1"
region: "eu-central-1",
uri: "https://foo.amazon.com"
}
}
{
Expand Down Expand Up @@ -106,12 +107,13 @@ class SmokeTestsRunnerGeneratorTest {
}
try {
com.test.TestClient {
TestClient {
interceptors.add(SmokeTestsInterceptor())
region = "eu-central-1"
uri = "https://foo.amazon.com"
}.use { client ->
client.testOperation(
com.test.model.TestOperationRequest {
TestOperationRequest {
bar = "2"
}
)
Expand Down Expand Up @@ -143,11 +145,10 @@ class SmokeTestsRunnerGeneratorTest {
}
try {
com.test.TestClient {
TestClient {
}.use { client ->
client.testOperation(
com.test.model.TestOperationRequest {
TestOperationRequest {
bar = "föö"
}
)
Expand Down Expand Up @@ -179,12 +180,11 @@ class SmokeTestsRunnerGeneratorTest {
}
try {
com.test.TestClient {
TestClient {
interceptors.add(SmokeTestsInterceptor())
}.use { client ->
client.testOperation(
com.test.model.TestOperationRequest {
TestOperationRequest {
bar = "föö"
}
)
Expand Down

0 comments on commit 9e1a8d6

Please sign in to comment.