diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/KotlinJmespathExpressionVisitor.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/KotlinJmespathExpressionVisitor.kt index 2bae5561b..4fa98957b 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/KotlinJmespathExpressionVisitor.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/KotlinJmespathExpressionVisitor.kt @@ -26,8 +26,6 @@ import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.Shape -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract private val suffixSequence = sequenceOf("") + generateSequence(2) { it + 1 }.map(Int::toString) // "", "2", "3", etc. @@ -76,14 +74,6 @@ class KotlinJmespathExpressionVisitor( private fun bestTempVarName(preferredName: String): String = suffixSequence.map { "$preferredName$it" }.first(tempVars::add) - @OptIn(ExperimentalContracts::class) - private fun codegenReq(condition: Boolean, lazyMessage: () -> String) { - contract { - returns() implies condition - } - if (!condition) throw CodegenException(lazyMessage()) - } - private fun flatMappingBlock(right: JmespathExpression, leftName: String, leftShape: Shape, innerShape: Shape?): VisitedExpression { if (right is CurrentExpression) return VisitedExpression(leftName, leftShape) // nothing to map @@ -161,24 +151,21 @@ class KotlinJmespathExpressionVisitor( val codegen = buildString { val nullables = buildList { - if (left.shape?.isNullable == true) add("${left.identifier} == null") - if (right.shape?.isNullable == true) add("${right.identifier} == null") + if (left.shape?.isNullable == true || left.nullable) add("${left.identifier} == null") + if (right.shape?.isNullable == true || right.nullable) add("${right.identifier} == null") } + if (nullables.isNotEmpty()) { val isNullExpr = nullables.joinToString(" || ") append("if ($isNullExpr) null else ") } - val unSafeComparatorExpr = "compareTo(${right.identifier})" - val comparatorExpr = if (left.nullable) "?.$unSafeComparatorExpr" else ".$unSafeComparatorExpr" - + val comparatorExpr = ".compareTo(${right.identifier}) ${expression.comparator} 0" append("${left.identifier}$comparatorExpr") } - val unSafeComparatorValue = addTempVar("unSafeComparator", codegen) - val safeComparatorValue = addTempVar("safeComparator", "if ($unSafeComparatorValue == null) null else $unSafeComparatorValue ${expression.comparator} 0") - - return VisitedExpression(safeComparatorValue) + val identifier = addTempVar("comparison", codegen) + return VisitedExpression(identifier) } override fun visitCurrentNode(expression: CurrentExpression): VisitedExpression { @@ -219,14 +206,11 @@ class KotlinJmespathExpressionVisitor( return VisitedExpression(ident, currentShape, inner.projected) } - private fun FunctionExpression.singleArg(): VisitedExpression { - codegenReq(arguments.size == 1) { "Unexpected number of arguments to $this" } - return acceptSubexpression(this.arguments[0]) - } - private fun FunctionExpression.twoArgs(): Pair { - codegenReq(arguments.size == 2) { "Unexpected number of arguments to $this" } - return acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1]) - } + private fun FunctionExpression.singleArg(): VisitedExpression = + acceptSubexpression(this.arguments[0]) + + private fun FunctionExpression.twoArgs(): Pair = + acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1]) private fun FunctionExpression.args(): List = this.arguments.map { acceptSubexpression(it) } @@ -345,35 +329,37 @@ class KotlinJmespathExpressionVisitor( } "sort" -> { - writer.addImport(RuntimeTypes.Core.Utils.jmespathSort) val arg = expression.singleArg() - arg.dotFunction(expression, "jmespathSort()") + arg.dotFunction(expression, "sorted()") } - "sort_by" -> mappingFunction(expression, "sortBy", this::sortBy) + "sort_by" -> { + val list = expression.arguments[0].accept(this) + val expressionValue = expression.arguments[1] + list.applyFunction(expression.name.toCamelCase(), "sortedBy", expressionValue) + } - "max_by" -> mappingFunction(expression, "maxBy", this::maxBy) + "max_by" -> { + val list = expression.arguments[0].accept(this) + val expressionValue = expression.arguments[1] + list.applyFunction(expression.name.toCamelCase(), "maxBy", expressionValue) + } - "min_by" -> mappingFunction(expression, "minBy", this::minBy) + "min_by" -> { + val list = expression.arguments[0].accept(this) + val expressionValue = expression.arguments[1] + list.applyFunction(expression.name.toCamelCase(), "minBy", expressionValue) + } - "map" -> mappingFunction(expression, "map", this::map, true) + "map" -> { + val list = expression.arguments[1].accept(this) + val expressionValue = expression.arguments[0] + list.applyFunction(expression.name.toCamelCase(), "map", expressionValue) + } else -> throw CodegenException("Unknown function type in $expression") } - private fun mappingFunction( - expression: FunctionExpression, - variableName: String, - function: (JmespathExpression, VisitedExpression) -> String, - invertedArgs: Boolean = false, - ): VisitedExpression { - val (argIndex, expressionIndex) = if (invertedArgs) 1 to 0 else 0 to 1 - - codegenReq(expression.arguments.size == 2) { "Unexpected number of arguments to $this" } - val arg = expression.arguments[argIndex].accept(this) - return VisitedExpression(addTempVar(variableName, function(expression.arguments[expressionIndex], arg))) - } - override fun visitIndex(expression: IndexExpression): VisitedExpression { throw CodegenException("IndexExpression is unsupported") } @@ -585,36 +571,20 @@ class KotlinJmespathExpressionVisitor( return notNull } - private fun mapFunctionLogic( - expression: JmespathExpression, - arg: VisitedExpression, - resultName: String, + private fun VisitedExpression.applyFunction( + name: String, operation: String, - stringOrNumberCheck: Boolean = true, - ): String { - val argName = arg.identifier - val result = bestTempVarName(resultName) - writer.withBlock("val $result = $argName?.$operation{", "}") { - val expressionValue = addTempVar("expression", subfieldCodegen((expression as ExpressionTypeExpression).expression as FieldExpression, "it")) - if (stringOrNumberCheck) { - write("if ($expressionValue as Any !is Number && $expressionValue as Any !is String) throw Exception(\"Result of applying expression should be string or number\")") - } - write(expressionValue) - } - return result - } - - private fun sortBy(expression: JmespathExpression, arg: VisitedExpression): String = - mapFunctionLogic(expression, arg, "sorted", "sortedBy") - - private fun maxBy(expression: JmespathExpression, arg: VisitedExpression): String = - mapFunctionLogic(expression, arg, "max", "maxBy") + expression: JmespathExpression, + ): VisitedExpression { + val result = bestTempVarName(name) - private fun minBy(expression: JmespathExpression, arg: VisitedExpression): String = - mapFunctionLogic(expression, arg, "min", "minBy") + writer.withBlock("val $result = ${this.identifier}?.$operation {", "}") { + val expressionValue = subfieldCodegen((expression as ExpressionTypeExpression).expression as FieldExpression, "it") + write("$expressionValue!!") + } - private fun map(expression: JmespathExpression, arg: VisitedExpression): String = - mapFunctionLogic(expression, arg, "map", "map", false) + return VisitedExpression(result) + } private val Shape.isNullable: Boolean get() = this is MemberShape && diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/JMESPath.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/JMESPath.kt index 13f28af5d..386fb17b9 100644 --- a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/JMESPath.kt +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/JMESPath.kt @@ -72,27 +72,3 @@ public fun Any?.type(): String = when (this) { null -> "null" else -> throw Exception("Undetected type for: $this") } - -@InternalApi -@JvmName("StringJmespathSort") -public fun List.jmespathSort(): List = this.sorted() - -@InternalApi -@JvmName("ShortJmespathSort") -public fun List.jmespathSort(): List = this.sorted() - -@InternalApi -@JvmName("IntJmespathSort") -public fun List.jmespathSort(): List = this.sorted() - -@InternalApi -@JvmName("FloatJmespathSort") -public fun List.jmespathSort(): List = this.sorted() - -@InternalApi -@JvmName("LongJmespathSort") -public fun List.jmespathSort(): List = this.sorted() - -@InternalApi -@JvmName("DoubleJmespathSort") -public fun List.jmespathSort(): List = this.sorted()