Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
0marperez committed Oct 5, 2023
1 parent 7e06e45 commit 10c8154
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract

private val suffixSequence = sequenceOf("") + generateSequence(2) { it + 1 }.map(Int::toString) // "", "2", "3", etc.

Expand Down Expand Up @@ -76,14 +74,6 @@ class KotlinJmespathExpressionVisitor(
private fun bestTempVarName(preferredName: String): String =
suffixSequence.map { "$preferredName$it" }.first(tempVars::add)

@OptIn(ExperimentalContracts::class)
private fun codegenReq(condition: Boolean, lazyMessage: () -> String) {
contract {
returns() implies condition
}
if (!condition) throw CodegenException(lazyMessage())
}

private fun flatMappingBlock(right: JmespathExpression, leftName: String, leftShape: Shape, innerShape: Shape?): VisitedExpression {
if (right is CurrentExpression) return VisitedExpression(leftName, leftShape) // nothing to map

Expand Down Expand Up @@ -161,24 +151,21 @@ class KotlinJmespathExpressionVisitor(

val codegen = buildString {
val nullables = buildList {
if (left.shape?.isNullable == true) add("${left.identifier} == null")
if (right.shape?.isNullable == true) add("${right.identifier} == null")
if (left.shape?.isNullable == true || left.nullable) add("${left.identifier} == null")
if (right.shape?.isNullable == true || right.nullable) add("${right.identifier} == null")
}

if (nullables.isNotEmpty()) {
val isNullExpr = nullables.joinToString(" || ")
append("if ($isNullExpr) null else ")
}

val unSafeComparatorExpr = "compareTo(${right.identifier})"
val comparatorExpr = if (left.nullable) "?.$unSafeComparatorExpr" else ".$unSafeComparatorExpr"

val comparatorExpr = ".compareTo(${right.identifier}) ${expression.comparator} 0"
append("${left.identifier}$comparatorExpr")
}

val unSafeComparatorValue = addTempVar("unSafeComparator", codegen)
val safeComparatorValue = addTempVar("safeComparator", "if ($unSafeComparatorValue == null) null else $unSafeComparatorValue ${expression.comparator} 0")

return VisitedExpression(safeComparatorValue)
val identifier = addTempVar("comparison", codegen)
return VisitedExpression(identifier)
}

override fun visitCurrentNode(expression: CurrentExpression): VisitedExpression {
Expand Down Expand Up @@ -219,14 +206,11 @@ class KotlinJmespathExpressionVisitor(
return VisitedExpression(ident, currentShape, inner.projected)
}

private fun FunctionExpression.singleArg(): VisitedExpression {
codegenReq(arguments.size == 1) { "Unexpected number of arguments to $this" }
return acceptSubexpression(this.arguments[0])
}
private fun FunctionExpression.twoArgs(): Pair<VisitedExpression, VisitedExpression> {
codegenReq(arguments.size == 2) { "Unexpected number of arguments to $this" }
return acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1])
}
private fun FunctionExpression.singleArg(): VisitedExpression =
acceptSubexpression(this.arguments[0])

private fun FunctionExpression.twoArgs(): Pair<VisitedExpression, VisitedExpression> =
acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1])

private fun FunctionExpression.args(): List<VisitedExpression> =
this.arguments.map { acceptSubexpression(it) }
Expand Down Expand Up @@ -345,35 +329,37 @@ class KotlinJmespathExpressionVisitor(
}

"sort" -> {
writer.addImport(RuntimeTypes.Core.Utils.jmespathSort)
val arg = expression.singleArg()
arg.dotFunction(expression, "jmespathSort()")
arg.dotFunction(expression, "sorted()")
}

"sort_by" -> mappingFunction(expression, "sortBy", this::sortBy)
"sort_by" -> {
val list = expression.arguments[0].accept(this)
val expressionValue = expression.arguments[1]
list.applyFunction(expression.name.toCamelCase(), "sortedBy", expressionValue)
}

"max_by" -> mappingFunction(expression, "maxBy", this::maxBy)
"max_by" -> {
val list = expression.arguments[0].accept(this)
val expressionValue = expression.arguments[1]
list.applyFunction(expression.name.toCamelCase(), "maxBy", expressionValue)
}

"min_by" -> mappingFunction(expression, "minBy", this::minBy)
"min_by" -> {
val list = expression.arguments[0].accept(this)
val expressionValue = expression.arguments[1]
list.applyFunction(expression.name.toCamelCase(), "minBy", expressionValue)
}

"map" -> mappingFunction(expression, "map", this::map, true)
"map" -> {
val list = expression.arguments[1].accept(this)
val expressionValue = expression.arguments[0]
list.applyFunction(expression.name.toCamelCase(), "map", expressionValue)
}

else -> throw CodegenException("Unknown function type in $expression")
}

private fun mappingFunction(
expression: FunctionExpression,
variableName: String,
function: (JmespathExpression, VisitedExpression) -> String,
invertedArgs: Boolean = false,
): VisitedExpression {
val (argIndex, expressionIndex) = if (invertedArgs) 1 to 0 else 0 to 1

codegenReq(expression.arguments.size == 2) { "Unexpected number of arguments to $this" }
val arg = expression.arguments[argIndex].accept(this)
return VisitedExpression(addTempVar(variableName, function(expression.arguments[expressionIndex], arg)))
}

override fun visitIndex(expression: IndexExpression): VisitedExpression {
throw CodegenException("IndexExpression is unsupported")
}
Expand Down Expand Up @@ -585,36 +571,20 @@ class KotlinJmespathExpressionVisitor(
return notNull
}

private fun mapFunctionLogic(
expression: JmespathExpression,
arg: VisitedExpression,
resultName: String,
private fun VisitedExpression.applyFunction(
name: String,
operation: String,
stringOrNumberCheck: Boolean = true,
): String {
val argName = arg.identifier
val result = bestTempVarName(resultName)
writer.withBlock("val $result = $argName?.$operation{", "}") {
val expressionValue = addTempVar("expression", subfieldCodegen((expression as ExpressionTypeExpression).expression as FieldExpression, "it"))
if (stringOrNumberCheck) {
write("if ($expressionValue as Any !is Number && $expressionValue as Any !is String) throw Exception(\"Result of applying expression should be string or number\")")
}
write(expressionValue)
}
return result
}

private fun sortBy(expression: JmespathExpression, arg: VisitedExpression): String =
mapFunctionLogic(expression, arg, "sorted", "sortedBy")

private fun maxBy(expression: JmespathExpression, arg: VisitedExpression): String =
mapFunctionLogic(expression, arg, "max", "maxBy")
expression: JmespathExpression,
): VisitedExpression {
val result = bestTempVarName(name)

private fun minBy(expression: JmespathExpression, arg: VisitedExpression): String =
mapFunctionLogic(expression, arg, "min", "minBy")
writer.withBlock("val $result = ${this.identifier}?.$operation {", "}") {
val expressionValue = subfieldCodegen((expression as ExpressionTypeExpression).expression as FieldExpression, "it")
write("$expressionValue!!")
}

private fun map(expression: JmespathExpression, arg: VisitedExpression): String =
mapFunctionLogic(expression, arg, "map", "map", false)
return VisitedExpression(result)
}

private val Shape.isNullable: Boolean
get() = this is MemberShape &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,3 @@ public fun Any?.type(): String = when (this) {
null -> "null"
else -> throw Exception("Undetected type for: $this")
}

@InternalApi
@JvmName("StringJmespathSort")
public fun List<String>.jmespathSort(): List<String> = this.sorted()

@InternalApi
@JvmName("ShortJmespathSort")
public fun List<Short>.jmespathSort(): List<Short> = this.sorted()

@InternalApi
@JvmName("IntJmespathSort")
public fun List<Int>.jmespathSort(): List<Int> = this.sorted()

@InternalApi
@JvmName("FloatJmespathSort")
public fun List<Float>.jmespathSort(): List<Float> = this.sorted()

@InternalApi
@JvmName("LongJmespathSort")
public fun List<Long>.jmespathSort(): List<Long> = this.sorted()

@InternalApi
@JvmName("DoubleJmespathSort")
public fun List<Double>.jmespathSort(): List<Double> = this.sorted()

0 comments on commit 10c8154

Please sign in to comment.