-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: type safe argument parsing using reflection
- Loading branch information
Showing
2 changed files
with
255 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
package technology.idlab.runner.jvm | ||
|
||
import kotlin.reflect.KType | ||
import kotlin.reflect.full.isSubclassOf | ||
import kotlin.reflect.full.isSuperclassOf | ||
import kotlin.reflect.jvm.jvmErasure | ||
import kotlin.reflect.typeOf | ||
import technology.idlab.util.Log | ||
|
||
/** | ||
* Recursively check if a value corresponds to a given KType. This function can be run either | ||
* strictly or loosely. This is due to type erasure, which means we must cast the value to the KType | ||
* and deal with type parameters manually. For example, for Pair<A, B>, we call the function | ||
* recursively for both the `first` as `second` data field. For List<T>, we call the function for | ||
* each element of the list. If a given container is not supported in that fashion, the result will | ||
* depend on the strict parameter. | ||
*/ | ||
fun safeCast(to: KType, from: Any, strict: Boolean = false): Boolean { | ||
// The base case, where the Any type matches everything. | ||
if (to.jvmErasure == Any::class) { | ||
return true | ||
} | ||
|
||
// The requested type must be an actual superclass of the value given. | ||
if (!to.jvmErasure.isSuperclassOf(from::class)) { | ||
return false | ||
} | ||
|
||
// Retrieve the type arguments. If these are empty, then we can safely assume that the type is | ||
// cast correctly and safely. | ||
val typeArguments = to.arguments | ||
if (typeArguments.isEmpty()) { | ||
return true | ||
} | ||
|
||
// If the value is pair, check both first and second. | ||
if (to.jvmErasure == Pair::class) { | ||
if (!from::class.isSubclassOf(Pair::class)) { | ||
return false | ||
} | ||
|
||
// Extract pair and the type arguments. | ||
@Suppress("UNCHECKED_CAST") val pair = from as Pair<Any, Any> | ||
val first = to.arguments[0].type!! | ||
val second = to.arguments[1].type!! | ||
|
||
return safeCast(first, pair.first) && safeCast(second, pair.second) | ||
} | ||
|
||
// If the value is a list, check all elements. | ||
if (to.jvmErasure == List::class) { | ||
if (!from::class.isSubclassOf(List::class)) { | ||
return false | ||
} | ||
|
||
// Extract values. | ||
@Suppress("UNCHECKED_CAST") val list = from as List<Any> | ||
val elementType = to.arguments[0].type!! | ||
return list.all { safeCast(elementType, it, strict) } | ||
} | ||
|
||
// We will never be able to exhaustively go over all types, due to type erasure. However, we're | ||
// if the user is okay with non-strict type checking, we may end here. | ||
return !strict | ||
} | ||
|
||
data class Arguments( | ||
val args: Map<String, List<Any>>, | ||
) { | ||
/** | ||
* Get an argument in a type safe way. The type parameter, either inferred or explicitly given, | ||
* will be used to recursively check the resulting type. Note that if you want to retrieve an | ||
* argument with type T which has Argument.Count.REQUIRED, you can either request type T directly | ||
* or the list with one element using the List<T> type. | ||
*/ | ||
inline operator fun <reified T> get(name: String, strict: Boolean = false): T { | ||
val type = typeOf<T>() | ||
|
||
// Retrieve the value from the map. | ||
val argumentList = | ||
this.args[name] | ||
?: if (type.isMarkedNullable) { | ||
return null as T | ||
} else { | ||
Log.shared.fatal("Argument $name is missing") | ||
} | ||
|
||
// Special case: check if the type is not a list, because in that case, we would need to get the | ||
// first element instead. | ||
val arg = | ||
if (T::class.isSuperclassOf(List::class)) { | ||
argumentList | ||
} else { | ||
if (argumentList.size != 1) { | ||
Log.shared.fatal("Cannot obtain single argument if there is not exactly one value.") | ||
} | ||
|
||
argumentList[0] | ||
} | ||
|
||
if (safeCast(type, arg, strict)) { | ||
return arg as T | ||
} else { | ||
Log.shared.fatal("Could not parse $name to ${T::class.simpleName}") | ||
} | ||
} | ||
|
||
companion object { | ||
/** | ||
* Parse a (nested) map into type-safe arguments. This method calls itself recursively for all | ||
* values which are maps as well. | ||
*/ | ||
fun from(args: Map<String, List<Any>>): Arguments { | ||
return Arguments( | ||
args.mapValues { (_, list) -> | ||
list.map { arg -> | ||
if (arg::class.isSubclassOf(Map::class)) { | ||
if (safeCast(typeOf<Map<String, List<Any>>>(), arg)) { | ||
@Suppress("UNCHECKED_CAST") Arguments.from(arg as Map<String, List<Any>>) | ||
} else { | ||
Log.shared.fatal("Cannot have raw maps in arguments.") | ||
} | ||
} else { | ||
arg | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
package runner.jvm | ||
|
||
import kotlin.test.assertEquals | ||
import org.junit.jupiter.api.Test | ||
import org.junit.jupiter.api.assertThrows | ||
import technology.idlab.exception.RunnerException | ||
import technology.idlab.runner.jvm.Arguments | ||
|
||
class ArgumentsTest { | ||
@Test | ||
fun single() { | ||
val args = Arguments(mapOf("key" to listOf("value"))) | ||
assertEquals("value", args.get<String>("key")) | ||
} | ||
|
||
@Test | ||
fun notSingle() { | ||
val args = Arguments(mapOf("key" to listOf("value1", "value2"))) | ||
assertThrows<RunnerException> { args.get<String>("key") } | ||
} | ||
|
||
@Test | ||
fun singleList() { | ||
val args = Arguments(mapOf("key" to listOf("value"))) | ||
assertEquals(listOf("value"), args.get<List<String>>("key")) | ||
} | ||
|
||
@Test | ||
fun longList() { | ||
val args = Arguments(mapOf("key" to listOf("value1", "value2"))) | ||
assertEquals(listOf("value1", "value2"), args.get<List<String>>("key")) | ||
} | ||
|
||
@Test | ||
fun longListWrong() { | ||
val args = Arguments(mapOf("key" to listOf("value1", "value2"))) | ||
|
||
assertThrows<RunnerException> { args.get<List<Int>>("key", strict = true) } | ||
} | ||
|
||
@Test | ||
fun nullable() { | ||
val args = Arguments(mapOf()) | ||
assertEquals(null, args.get<String?>("key")) | ||
} | ||
|
||
@Test | ||
fun nonNullable() { | ||
val args = Arguments(mapOf()) | ||
assertThrows<RunnerException> { args.get<String>("key") } | ||
} | ||
|
||
@Test | ||
fun invalidCast() { | ||
val args = Arguments(mapOf("key" to listOf("value"))) | ||
assertThrows<RunnerException> { args.get<Int>("key") } | ||
} | ||
|
||
@Test | ||
fun pairs() { | ||
val args = Arguments(mapOf("first" to listOf(Pair(1, "a")), "second" to listOf(Pair(2, "b")))) | ||
|
||
// Get first pair correctly. | ||
val first = args.get<Pair<Int, String>>("first") | ||
assertEquals(1, first.first) | ||
assertEquals("a", first.second) | ||
|
||
// Get second pair correctly, use operator syntax. | ||
val second: Pair<Int, String> = args["second"] | ||
assertEquals(2, second.first) | ||
assertEquals("b", second.second) | ||
|
||
// Get first pair as a list. | ||
val firstList = args.get<List<Pair<Int, String>>>("first") | ||
assertEquals(1, firstList[0].first) | ||
assertEquals("a", firstList[0].second) | ||
|
||
// Same for second, use operator syntax. | ||
val secondList: List<Pair<Int, String>> = args["second"] | ||
assertEquals(2, secondList[0].first) | ||
assertEquals("b", secondList[0].second) | ||
|
||
// Attempt to get integer as double, in strict mode. | ||
assertThrows<RunnerException> { args.get<Pair<Double, String>>("first", strict = true) } | ||
|
||
// Attempt to get string as integer. | ||
assertThrows<RunnerException> { args.get<Pair<Int, Int>>("first") } | ||
} | ||
|
||
@Test | ||
fun nested() { | ||
val args = Arguments.from(mapOf("root" to listOf(mapOf("leaf" to listOf("Hello, World!"))))) | ||
|
||
val value = args.get<Arguments>("root").get<String>("leaf") | ||
assertEquals("Hello, World!", value) | ||
} | ||
|
||
@Test | ||
fun inheritance() { | ||
// The base class. | ||
open class A | ||
|
||
// The extended class. | ||
open class B : A() | ||
|
||
// The extended, extended class. | ||
class C : B() | ||
|
||
// Create three arguments, each with the lists. | ||
val args = | ||
Arguments(mapOf("a" to listOf(A(), A()), "b" to listOf(B(), B()), "c" to listOf(C(), C()))) | ||
|
||
assertEquals(2, args.get<List<A>>("a", strict = true).size) | ||
assertEquals(2, args.get<List<A>>("b", strict = true).size) | ||
assertEquals(2, args.get<List<A>>("c", strict = true).size) | ||
|
||
assertThrows<RunnerException> { args.get<List<B>>("a", strict = true) } | ||
assertEquals(2, args.get<List<B>>("b", strict = true).size) | ||
assertEquals(2, args.get<List<B>>("c", strict = true).size) | ||
|
||
assertThrows<RunnerException> { args.get<List<C>>("a", strict = true) } | ||
assertThrows<RunnerException> { args.get<List<C>>("b", strict = true) } | ||
assertEquals(2, args.get<List<C>>("c", strict = true).size) | ||
} | ||
} |