Skip to content

Commit

Permalink
feat: type safe argument parsing using reflection
Browse files Browse the repository at this point in the history
  • Loading branch information
jenspots committed Jul 4, 2024
1 parent 09e9a48 commit de311d1
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 0 deletions.
130 changes: 130 additions & 0 deletions src/main/kotlin/runner/jvm/Arguments.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package technology.idlab.runner.jvm

import kotlin.reflect.KType
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.isSuperclassOf
import kotlin.reflect.jvm.jvmErasure
import kotlin.reflect.typeOf
import technology.idlab.util.Log

/**
* Recursively check if a value corresponds to a given KType. This function can be run either
* strictly or loosely. This is due to type erasure, which means we must cast the value to the KType
* and deal with type parameters manually. For example, for Pair<A, B>, we call the function
* recursively for both the `first` as `second` data field. For List<T>, we call the function for
* each element of the list. If a given container is not supported in that fashion, the result will
* depend on the strict parameter.
*/
fun safeCast(to: KType, from: Any, strict: Boolean = false): Boolean {
// The base case, where the Any type matches everything.
if (to.jvmErasure == Any::class) {
return true
}

// The requested type must be an actual superclass of the value given.
if (!to.jvmErasure.isSuperclassOf(from::class)) {
return false
}

// Retrieve the type arguments. If these are empty, then we can safely assume that the type is
// cast correctly and safely.
val typeArguments = to.arguments
if (typeArguments.isEmpty()) {
return true
}

// If the value is pair, check both first and second.
if (to.jvmErasure == Pair::class) {
if (!from::class.isSubclassOf(Pair::class)) {
return false
}

// Extract pair and the type arguments.
@Suppress("UNCHECKED_CAST") val pair = from as Pair<Any, Any>
val first = to.arguments[0].type!!
val second = to.arguments[1].type!!

return safeCast(first, pair.first) && safeCast(second, pair.second)
}

// If the value is a list, check all elements.
if (to.jvmErasure == List::class) {
if (!from::class.isSubclassOf(List::class)) {
return false
}

// Extract values.
@Suppress("UNCHECKED_CAST") val list = from as List<Any>
val elementType = to.arguments[0].type!!
return list.all { safeCast(elementType, it, strict) }
}

// We will never be able to exhaustively go over all types, due to type erasure. However, we're
// if the user is okay with non-strict type checking, we may end here.
return !strict
}

data class Arguments(
val args: Map<String, List<Any>>,
) {
/**
* Get an argument in a type safe way. The type parameter, either inferred or explicitly given,
* will be used to recursively check the resulting type. Note that if you want to retrieve an
* argument with type T which has Argument.Count.REQUIRED, you can either request type T directly
* or the list with one element using the List<T> type.
*/
inline operator fun <reified T> get(name: String, strict: Boolean = false): T {
val type = typeOf<T>()

// Retrieve the value from the map.
val argumentList =
this.args[name]
?: if (type.isMarkedNullable) {
return null as T
} else {
Log.shared.fatal("Argument $name is missing")
}

// Special case: check if the type is not a list, because in that case, we would need to get the
// first element instead.
val arg =
if (T::class.isSuperclassOf(List::class)) {
argumentList
} else {
if (argumentList.size != 1) {
Log.shared.fatal("Cannot obtain single argument if there is not exactly one value.")
}

argumentList[0]
}

if (safeCast(type, arg, strict)) {
return arg as T
} else {
Log.shared.fatal("Could not parse $name to ${T::class.simpleName}")
}
}

companion object {
/**
* Parse a (nested) map into type-safe arguments. This method calls itself recursively for all
* values which are maps as well.
*/
fun from(args: Map<String, List<Any>>): Arguments {
return Arguments(
args.mapValues { (_, list) ->
list.map { arg ->
if (arg::class.isSubclassOf(Map::class)) {
if (safeCast(typeOf<Map<String, List<Any>>>(), arg)) {
@Suppress("UNCHECKED_CAST") Arguments.from(arg as Map<String, List<Any>>)
} else {
Log.shared.fatal("Cannot have raw maps in arguments.")
}
} else {
arg
}
}
})
}
}
}
125 changes: 125 additions & 0 deletions src/test/kotlin/runner/jvm/ArgumentsTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package runner.jvm

import kotlin.test.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import technology.idlab.exception.RunnerException
import technology.idlab.runner.jvm.Arguments

class ArgumentsTest {
@Test
fun single() {
val args = Arguments(mapOf("key" to listOf("value")))
assertEquals("value", args.get<String>("key"))
}

@Test
fun notSingle() {
val args = Arguments(mapOf("key" to listOf("value1", "value2")))
assertThrows<RunnerException> { args.get<String>("key") }
}

@Test
fun singleList() {
val args = Arguments(mapOf("key" to listOf("value")))
assertEquals(listOf("value"), args.get<List<String>>("key"))
}

@Test
fun longList() {
val args = Arguments(mapOf("key" to listOf("value1", "value2")))
assertEquals(listOf("value1", "value2"), args.get<List<String>>("key"))
}

@Test
fun longListWrong() {
val args = Arguments(mapOf("key" to listOf("value1", "value2")))

assertThrows<RunnerException> { args.get<List<Int>>("key", strict = true) }
}

@Test
fun nullable() {
val args = Arguments(mapOf())
assertEquals(null, args.get<String?>("key"))
}

@Test
fun nonNullable() {
val args = Arguments(mapOf())
assertThrows<RunnerException> { args.get<String>("key") }
}

@Test
fun invalidCast() {
val args = Arguments(mapOf("key" to listOf("value")))
assertThrows<RunnerException> { args.get<Int>("key") }
}

@Test
fun pairs() {
val args = Arguments(mapOf("first" to listOf(Pair(1, "a")), "second" to listOf(Pair(2, "b"))))

// Get first pair correctly.
val first = args.get<Pair<Int, String>>("first")
assertEquals(1, first.first)
assertEquals("a", first.second)

// Get second pair correctly, use operator syntax.
val second: Pair<Int, String> = args["second"]
assertEquals(2, second.first)
assertEquals("b", second.second)

// Get first pair as a list.
val firstList = args.get<List<Pair<Int, String>>>("first")
assertEquals(1, firstList[0].first)
assertEquals("a", firstList[0].second)

// Same for second, use operator syntax.
val secondList: List<Pair<Int, String>> = args["second"]
assertEquals(2, secondList[0].first)
assertEquals("b", secondList[0].second)

// Attempt to get integer as double, in strict mode.
assertThrows<RunnerException> { args.get<Pair<Double, String>>("first", strict = true) }

// Attempt to get string as integer.
assertThrows<RunnerException> { args.get<Pair<Int, Int>>("first") }
}

@Test
fun nested() {
val args = Arguments.from(mapOf("root" to listOf(mapOf("leaf" to listOf("Hello, World!")))))

val value = args.get<Arguments>("root").get<String>("leaf")
assertEquals("Hello, World!", value)
}

@Test
fun inheritance() {
// The base class.
open class A

// The extended class.
open class B : A()

// The extended, extended class.
class C : B()

// Create three arguments, each with the lists.
val args =
Arguments(mapOf("a" to listOf(A(), A()), "b" to listOf(B(), B()), "c" to listOf(C(), C())))

assertEquals(2, args.get<List<A>>("a", strict = true).size)
assertEquals(2, args.get<List<A>>("b", strict = true).size)
assertEquals(2, args.get<List<A>>("c", strict = true).size)

assertThrows<RunnerException> { args.get<List<B>>("a", strict = true) }
assertEquals(2, args.get<List<B>>("b", strict = true).size)
assertEquals(2, args.get<List<B>>("c", strict = true).size)

assertThrows<RunnerException> { args.get<List<C>>("a", strict = true) }
assertThrows<RunnerException> { args.get<List<C>>("b", strict = true) }
assertEquals(2, args.get<List<C>>("c", strict = true).size)
}
}

0 comments on commit de311d1

Please sign in to comment.