Skip to content

Commit

Permalink
Adding a test-helper class TestCoroutineContext.
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton Spaans committed May 7, 2018
1 parent 20dbd9f commit 289f3ba
Show file tree
Hide file tree
Showing 4 changed files with 723 additions and 0 deletions.
10 changes: 10 additions & 0 deletions core/kotlinx-coroutines-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ This module provides debugging facilities for coroutines (run JVM with `-ea` or
and [newCoroutineContext] function to write user-defined coroutine builders that work with these
debugging facilities.

This module provides a special CoroutineContext type [TestCoroutineCoroutineContext][kotlinx.coroutines.experimental.test.TestCoroutineContext] that
allows the writer of code that contains Coroutines with delays and timeouts to write non-flaky unit-tests for that code allowing these tests to
terminate in near zero time. See the documentation for this class for more information.

# Package kotlinx.coroutines.experimental

General-purpose coroutine builders, contexts, and helper functions.
Expand All @@ -93,6 +97,10 @@ Low-level primitives for finer-grained control of coroutines.

Optional time unit support for multiplatform projects.

# Package kotlinx.coroutines.experimental.test

Components to ease writing unit-tests for code that contains coroutines with delays and timeouts.

<!--- MODULE kotlinx-coroutines-core -->
<!--- INDEX kotlinx.coroutines.experimental -->
[launch]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental/launch.html
Expand Down Expand Up @@ -148,4 +156,6 @@ Optional time unit support for multiplatform projects.
<!--- INDEX kotlinx.coroutines.experimental.selects -->
[kotlinx.coroutines.experimental.selects.select]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental.selects/select.html
[kotlinx.coroutines.experimental.selects.SelectBuilder.onTimeout]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental.selects/-select-builder/on-timeout.html
<!--- INDEX kotlinx.coroutines.experimental.test -->
[kotlinx.coroutines.experimental.test.TestCoroutineContext]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental.test/-test-coroutine-context/index.html
<!--- END -->
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package kotlinx.coroutines.experimental.internal

import java.util.*

/**
* @suppress **This is unstable API and it is subject to change.**
*/
Expand All @@ -36,6 +38,11 @@ public class ThreadSafeHeap<T> where T: ThreadSafeHeapNode, T: Comparable<T> {

public val isEmpty: Boolean get() = size == 0

public fun clear() = synchronized(this) {
Arrays.fill(a, 0, size, null)
size = 0
}

public fun peek(): T? = synchronized(this) { firstImpl() }

public fun removeFirstOrNull(): T? = synchronized(this) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
/*
* Copyright 2016-2018 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package kotlinx.coroutines.experimental.test

import kotlinx.coroutines.experimental.*
import kotlinx.coroutines.experimental.internal.*
import java.util.concurrent.TimeUnit
import kotlin.coroutines.experimental.*

/**
* This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
* code, especially tests, that deal with delays and timeouts in Coroutines.
*
* Provide an instance of this TestCoroutineContext when calling the *non-blocking* [launch] or [async]
* and then advance time or trigger the actions to make the co-routines execute as soon as possible.
*
* This works much like the *TestScheduler* in RxJava2, which allows to speed up tests that deal
* with non-blocking Rx chains that contain delays, timeouts, intervals and such.
*
* This dispatcher can also handle *blocking* coroutines that are started by [runBlocking].
* This dispatcher's virtual time will be automatically advanced based based on the delayed actions
* within the Coroutine(s).
*
* @param name A user-readable name for debugging purposes.
*/
class TestCoroutineContext(private val name: String? = null) : CoroutineContext {
private val uncaughtExceptions = mutableListOf<Throwable>()

private val ctxDispatcher = Dispatcher()

private val ctxHandler = CoroutineExceptionHandler { _, exception ->
uncaughtExceptions += exception
}

// The ordered queue for the runnable tasks.
private val queue = ThreadSafeHeap<TimedRunnable>()

// The per-scheduler global order counter.
private var counter = 0L

// Storing time in nanoseconds internally.
private var time = 0L

/**
* Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
*/
public val exceptions: List<Throwable> get() = uncaughtExceptions

// -- CoroutineContext implementation

public override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
operation(operation(initial, ctxDispatcher), ctxHandler)

@Suppress("UNCHECKED_CAST")
public override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = when {
key === ContinuationInterceptor -> ctxDispatcher as E
key === CoroutineExceptionHandler -> ctxHandler as E
else -> null
}

public override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext = when {
key === ContinuationInterceptor -> ctxHandler
key === CoroutineExceptionHandler -> ctxDispatcher
else -> this
}

/**
* Returns the current virtual clock-time as it is known to this CoroutineContext.
*
* @param unit The [TimeUnit] in which the clock-time must be returned.
* @return The virtual clock-time
*/
public fun now(unit: TimeUnit = TimeUnit.MILLISECONDS)=
unit.convert(time, TimeUnit.NANOSECONDS)

/**
* Moves the CoroutineContext's virtual clock forward by a specified amount of time.
*
* The returned delay-time can be larger than the specified delay-time if the code
* under test contains *blocking* Coroutines.
*
* @param delayTime The amount of time to move the CoroutineContext's clock forward.
* @param unit The [TimeUnit] in which [delayTime] and the return value is expressed.
* @return The amount of delay-time that this CoroutinesContext's clock has been forwarded.
*/
public fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS): Long {
val oldTime = time
advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit.NANOSECONDS)
return unit.convert(time - oldTime, TimeUnit.NANOSECONDS)
}

/**
* Moves the CoroutineContext's clock-time to a particular moment in time.
*
* @param targetTime The point in time to which to move the CoroutineContext's clock.
* @param unit The [TimeUnit] in which [targetTime] is expressed.
*/
fun advanceTimeTo(targetTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
val nanoTime = unit.toNanos(targetTime)
triggerActions(nanoTime)
if (nanoTime > time) time = nanoTime
}

/**
* Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
* before this CoroutineContext's present virtual clock-time.
*/
public fun triggerActions() = triggerActions(time)

/**
* Cancels all not yet triggered actions. Be careful calling this, since it can seriously
* mess with your coroutines work. This method should usually be called on tear-down of a
* unit test.
*/
public fun cancelAllActions() {
// An 'is-empty' test is required to avoid a NullPointerException in the 'clear()' method
if (!queue.isEmpty) queue.clear()
}

/**
* This method does nothing if there is one unhandled exception that satisfies the given predicate.
* Otherwise it throws an [AssertionError] with the given message.
*
* (this method will clear the list of unhandled exceptions)
*
* @param message Message of the [AssertionError]. Defaults to an empty String.
* @param predicate The predicate that must be satisfied.
*/
public fun assertUnhandledException(message: String = "", predicate: (Throwable) -> Boolean) {
if (uncaughtExceptions.size != 1 || !predicate(uncaughtExceptions[0])) throw AssertionError(message)
uncaughtExceptions.clear()
}

/**
* This method does nothing if there are no unhandled exceptions or all of them satisfy the given predicate.
* Otherwise it throws an [AssertionError] with the given message.
*
* (this method will clear the list of unhandled exceptions)
*
* @param message Message of the [AssertionError]. Defaults to an empty String.
* @param predicate The predicate that must be satisfied.
*/
public fun assertAllUnhandledExceptions(message: String = "", predicate: (Throwable) -> Boolean) {
if (!uncaughtExceptions.all(predicate)) throw AssertionError(message)
uncaughtExceptions.clear()
}

/**
* This method does nothing if one or more unhandled exceptions satisfy the given predicate.
* Otherwise it throws an [AssertionError] with the given message.
*
* (this method will clear the list of unhandled exceptions)
*
* @param message Message of the [AssertionError]. Defaults to an empty String.
* @param predicate The predicate that must be satisfied.
*/
public fun assertAnyUnhandledException(message: String = "", predicate: (Throwable) -> Boolean) {
if (!uncaughtExceptions.any(predicate)) throw AssertionError(message)
uncaughtExceptions.clear()
}

/**
* This method does nothing if the list of unhandled exceptions satisfy the given predicate.
* Otherwise it throws an [AssertionError] with the given message.
*
* (this method will clear the list of unhandled exceptions)
*
* @param message Message of the [AssertionError]. Defaults to an empty String.
* @param predicate The predicate that must be satisfied.
*/
public fun assertExceptions(message: String = "", predicate: (List<Throwable>) -> Boolean) {
if (!predicate(uncaughtExceptions)) throw AssertionError(message)
uncaughtExceptions.clear()
}

private fun post(block: Runnable) =
queue.addLast(TimedRunnable(block, counter++))

private fun postDelayed(block: Runnable, delayTime: Long) =
TimedRunnable(block, counter++, time + TimeUnit.MILLISECONDS.toNanos(delayTime))
.also {
queue.addLast(it)
}

private fun processNextEvent(): Long {
val current = queue.peek()
if (current != null) {
/** Automatically advance time for [EventLoop]-callbacks */
triggerActions(current.time)
}
return if (queue.isEmpty) Long.MAX_VALUE else 0L
}

private fun triggerActions(targetTime: Long) {
while (true) {
val current = queue.removeFirstIf { it.time <= targetTime } ?: break
// If the scheduled time is 0 (immediate) use current virtual time
if (current.time != 0L) time = current.time
current.run()
}
}

public override fun toString(): String = name ?: "TestCoroutineContext@$hexAddress"

private inner class Dispatcher : CoroutineDispatcher(), Delay, EventLoop {
override fun dispatch(context: CoroutineContext, block: Runnable) = post(block)

override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) {
postDelayed(Runnable {
with(continuation) { resumeUndispatched(Unit) }
}, unit.toMillis(time))
}

override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
val node = postDelayed(block, unit.toMillis(time))
return object : DisposableHandle {
override fun dispose() {
queue.remove(node)
}
}
}

override fun processNextEvent() = this@TestCoroutineContext.processNextEvent()

public override fun toString(): String = "Dispatcher(${this@TestCoroutineContext})"
}
}

private class TimedRunnable(
private val run: Runnable,
private val count: Long = 0,
@JvmField internal val time: Long = 0
) : Comparable<TimedRunnable>, Runnable by run, ThreadSafeHeapNode {
override var index: Int = 0

override fun run() = run.run()

override fun compareTo(other: TimedRunnable) = if (time == other.time) {
count.compareTo(other.count)
} else {
time.compareTo(other.time)
}

override fun toString() = "TimedRunnable(time=$time, run=$run)"
}

/**
* Executes a block of code in which a unit-test can be written using the provided [TestCoroutineContext]. The provided
* [TestCoroutineContext] is available in the [testBody] as the `this` receiver.
*
* The [testBody] is executed and an [AssertionError] is thrown if the list of unhandled exceptions is not empty and
* contains any exception that is not a [CancellationException].
*
* If the [testBody] successfully executes one of the [TestCoroutineContext.assertAllUnhandledExceptions],
* [TestCoroutineContext.assertAnyUnhandledException], [TestCoroutineContext.assertUnhandledException] or
* [TestCoroutineContext.assertExceptions], the list of unhandled exceptions will have been cleared and this method will
* not throw an [AssertionError].
*
* @param testContext The provided [TestCoroutineContext]. If not specified, a default [TestCoroutineContext] will be
* provided instead.
* @param testBody The code of the unit-test.
*/
public fun withTestContext(testContext: TestCoroutineContext = TestCoroutineContext(), testBody: TestCoroutineContext.() -> Unit) {
with (testContext) {
testBody()

if (!exceptions.all { it is CancellationException }) {
throw AssertionError("Coroutine encountered unhandled exceptions:\n${exceptions}")
}
}
}

/* Some helper functions */
public fun TestCoroutineContext.launch(
start: CoroutineStart = CoroutineStart.DEFAULT,
parent: Job? = null,
onCompletion: CompletionHandler? = null,
block: suspend CoroutineScope.() -> Unit
) = launch(this, start, parent, onCompletion, block)

public fun <T> TestCoroutineContext.async(
start: CoroutineStart = CoroutineStart.DEFAULT,
parent: Job? = null,
onCompletion: CompletionHandler? = null,
block: suspend CoroutineScope.() -> T

) = async(this, start, parent, onCompletion, block)

public fun <T> TestCoroutineContext.runBlocking(
block: suspend CoroutineScope.() -> T
) = runBlocking(this, block)
Loading

0 comments on commit 289f3ba

Please sign in to comment.