Skip to content

Commit

Permalink
NRL Operation must be Recoverable
Browse files Browse the repository at this point in the history
  • Loading branch information
zuevmaxim committed Oct 5, 2021
1 parent 232edd9 commit 1a485d5
Show file tree
Hide file tree
Showing 14 changed files with 98 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ import org.objectweb.asm.commons.GeneratorAdapter
import org.objectweb.asm.commons.Method
import kotlin.reflect.jvm.javaMethod

internal open class CrashEnabledVisitor(cv: ClassVisitor, testClass: Class<*>, initial: Boolean = true) :
internal open class CrashEnabledVisitor(cv: ClassVisitor, initial: Boolean = true) :
ClassVisitor(ASM_API, cv) {
private val superClassNames = testClass.superClassNames()
var shouldTransform = initial
private set
var name: String? = null
Expand All @@ -48,12 +47,6 @@ internal open class CrashEnabledVisitor(cv: ClassVisitor, testClass: Class<*>, i
) {
super.visit(version, access, name, signature, superName, interfaces)
this.name = name
if (name in superClassNames || name !== null &&
name.startsWith("org.jetbrains.kotlinx.lincheck.") &&
!name.startsWith("org.jetbrains.kotlinx.lincheck.test.")
) {
shouldTransform = false
}
}

override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor {
Expand All @@ -69,10 +62,7 @@ internal open class CrashEnabledVisitor(cv: ClassVisitor, testClass: Class<*>, i
}
}

internal class CrashTransformer(
cv: ClassVisitor,
testClass: Class<*>
) : CrashEnabledVisitor(cv, testClass) {
internal class CrashTransformer(cv: ClassVisitor) : CrashEnabledVisitor(cv) {
override fun visitMethod(
access: Int,
name: String?,
Expand Down Expand Up @@ -193,13 +183,3 @@ internal class CrashRethrowTransformer(cv: ClassVisitor) : ClassVisitor(ASM_API,
}
}
}

private fun Class<*>.superClassNames(): List<String> {
val result = mutableListOf<String>()
var clazz: Class<*>? = this
while (clazz !== null) {
result.add(Type.getInternalName(clazz))
clazz = clazz.superclass
}
return result
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

package org.jetbrains.kotlinx.lincheck.nvm

import org.jetbrains.kotlinx.lincheck.annotations.Operation
import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
import org.jetbrains.kotlinx.lincheck.execution.ExecutionScenario
import org.jetbrains.kotlinx.lincheck.verifier.Verifier
import org.jetbrains.kotlinx.lincheck.verifier.linearizability.LinearizabilityVerifier
Expand Down Expand Up @@ -76,8 +78,8 @@ private object RecoverExecutionCallback : ExecutionCallback {
internal enum class StrategyRecoveryOptions {
STRESS, MANAGED;

fun createCrashTransformer(cv: ClassVisitor, clazz: Class<*>): ClassVisitor = when (this) {
STRESS -> CrashRethrowTransformer(CrashTransformer(cv, clazz))
fun createCrashTransformer(cv: ClassVisitor): ClassVisitor = when (this) {
STRESS -> CrashRethrowTransformer(CrashTransformer(cv))
MANAGED -> CrashRethrowTransformer(cv) // add crashes in ManagedStrategyTransformer
}
}
Expand Down Expand Up @@ -111,6 +113,7 @@ interface RecoverabilityModel {
fun defaultExpectedCrashes(): Int
fun createExecutionCallback(): ExecutionCallback
fun createProbabilityModel(): ProbabilityModel
fun checkTestClass(testClass: Class<*>) {}
val awaitSystemCrashBeforeThrow: Boolean
val verifierClass: Class<out Verifier>

Expand Down Expand Up @@ -151,10 +154,24 @@ private class NRLModel(
override fun createTransformer(cv: ClassVisitor, clazz: Class<*>): ClassVisitor {
var result: ClassVisitor = RecoverabilityTransformer(cv)
if (crashes) {
result = strategyRecoveryOptions.createCrashTransformer(result, clazz)
result = strategyRecoveryOptions.createCrashTransformer(result)
}
return result
}

override fun checkTestClass(testClass: Class<*>) {
var clazz: Class<*>? = testClass
while (clazz !== null) {
clazz.declaredMethods.forEach { method ->
val isOperation = method.isAnnotationPresent(Operation::class.java)
val isRecoverable = method.isAnnotationPresent(Recoverable::class.java)
require(!isOperation || isRecoverable) {
"Every operation must have a Recovery annotation, but ${method.name} operation in ${clazz!!.name} class is not Recoverable."
}
}
clazz = clazz.superclass
}
}
}

private open class DurableModel(val strategyRecoveryOptions: StrategyRecoveryOptions) : RecoverabilityModel {
Expand All @@ -171,7 +188,7 @@ private open class DurableModel(val strategyRecoveryOptions: StrategyRecoveryOpt
override val verifierClass: Class<out Verifier> get() = DurableLinearizabilityVerifier::class.java
override fun createTransformerWrapper(cv: ClassVisitor, clazz: Class<*>) = DurableRecoverAllGenerator(cv, clazz)
override fun createTransformer(cv: ClassVisitor, clazz: Class<*>): ClassVisitor =
strategyRecoveryOptions.createCrashTransformer(DurableOperationRecoverTransformer(cv, clazz), clazz)
strategyRecoveryOptions.createCrashTransformer(DurableOperationRecoverTransformer(cv, clazz))
}

private class DetectableExecutionModel(strategyRecoveryOptions: StrategyRecoveryOptions) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ internal class SwitchesAndCrashesModelCheckingStrategy(
override fun createRoot(): InterleavingTreeNode = ThreadChoosingNodeWithCrashes((0 until nThreads).toList())

override fun createTransformer(cv: ClassVisitor): ClassVisitor {
val visitor = CrashEnabledVisitor(cv, testClass, recoverModel.crashes)
val visitor = CrashEnabledVisitor(cv, recoverModel.crashes)
val recoverTransformer = recoverModel.createTransformer(visitor, testClass)
val managedTransformer = CrashesManagedStrategyTransformer(
recoverTransformer, tracePointConstructors, testCfg.guarantees, testCfg.eliminateLocalObjects,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ internal open class ParallelThreadsRunner(
override fun initialize() {
executionCallback.reset(scenario, recoverModel)
super.initialize()
recoverModel.checkTestClass(testClass)
testThreadExecutions = Array(scenario.threads) { t ->
TestThreadExecutionGenerator.create(this, t, scenario.parallelExecution[t], completions[t], scenario.hasSuspendableActors(), recoverModel.createActorCrashHandlerGenerator())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package org.jetbrains.kotlinx.lincheck.test.nvm

import org.jetbrains.kotlinx.lincheck.LinChecker
import org.jetbrains.kotlinx.lincheck.annotations.Operation
import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
import org.jetbrains.kotlinx.lincheck.nvm.Recover
import org.jetbrains.kotlinx.lincheck.nvm.api.nonVolatile
import org.jetbrains.kotlinx.lincheck.strategy.stress.StressCTest
Expand All @@ -33,9 +34,11 @@ internal class PersistentTest {
private val x = nonVolatile(0)

@Operation
@Recoverable
fun read() = x.value

@Operation
@Recoverable
fun write(value: Int) {
x.value = value
x.flush()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package org.jetbrains.kotlinx.lincheck.test.transformation

import org.jetbrains.kotlinx.lincheck.LinChecker
import org.jetbrains.kotlinx.lincheck.annotations.Operation
import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
import org.jetbrains.kotlinx.lincheck.nvm.Recover
import org.jetbrains.kotlinx.lincheck.strategy.stress.StressCTest
import org.jetbrains.kotlinx.lincheck.verifier.VerifierState
Expand All @@ -32,6 +33,7 @@ import org.junit.Test
internal class LincheckClassCrashFreeTest : VerifierState() {

@Operation
@Recoverable
fun simple() = 42

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ internal class CrashInsertTest : VerifierState() {
private val c = NVMClass()

@Operation
@Recoverable
fun foo() = c.foo()
override fun extractState() = 4

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ private data class CrashPosition(val iActor: Int, val line: Int) : Comparable<Cr

class UniformDistributedCrashesTest {
@Test
fun testDurable() = test(/* 2 foo1 + 3 foo2 */37, Recover.DURABLE, SequentialCodeTest::class.java, 1)
fun testDurable() = test(/* 2 foo1 + 3 foo2 */42, Recover.DURABLE, SequentialCodeTest::class.java, 1)

@Test
@Ignore("This model works only with unbounded number of crashes. But this is uncomfortable to use & analyze.")
fun testDetectableExecution() =
test(/* 2 foo1 + 3 foo2 */37, Recover.DETECTABLE_EXECUTION, SequentialCodeTest::class.java, 5)
test(/* 2 foo1 + 3 foo2 */42, Recover.DETECTABLE_EXECUTION, SequentialCodeTest::class.java, 5)

private fun test(crashPoints: Int, model: Recover, testClass: Class<*>, expectedCrashes: Int) {
val n = 1_000_000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelChecki
import org.jetbrains.kotlinx.lincheck.strategy.stress.StressOptions
import org.jetbrains.kotlinx.lincheck.test.checkTraceHasNoLincheckEvents
import org.junit.Test
import java.lang.IllegalStateException
import java.lang.reflect.InvocationTargetException
import kotlin.reflect.KClass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ private const val THREADS_NUMBER = 3
internal interface Counter {
fun increment(threadId: Int)
fun get(threadId: Int): Int
fun incrementBefore(p: Int) {}
fun incrementRecover(p: Int) {}
}

/**
Expand All @@ -43,9 +45,13 @@ internal class CounterTest : AbstractNVMLincheckTest(Recover.NRL, THREADS_NUMBER
private val counter = NRLCounter(THREADS_NUMBER + 2)

@Operation
@Recoverable(beforeMethod = "incrementBefore", recoverMethod = "incrementRecover")
override fun increment(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.increment(threadId)
override fun incrementBefore(p: Int) = counter.incrementBefore(p)
override fun incrementRecover(p: Int) = counter.incrementRecover(p)

@Operation
@Recoverable
override fun get(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.get(threadId)
}

Expand All @@ -65,22 +71,19 @@ internal open class NRLCounter(threadsCount: Int) : Counter {
protected val checkPointer = MutableList(threadsCount) { nonVolatile(0) }
protected val currentValue = MutableList(threadsCount) { nonVolatile(0) }

@Recoverable
override fun get(threadId: Int) = r.sumBy { it.read()!! }

@Recoverable(beforeMethod = "incrementBefore", recoverMethod = "incrementRecover")
override fun increment(threadId: Int) = incrementImpl(threadId)

protected open fun incrementImpl(p: Int) {
r[p].write(1 + currentValue[p].value, p)
checkPointer[p].value = 1
}

protected open fun incrementRecover(p: Int) {
override fun incrementRecover(p: Int) {
if (checkPointer[p].value == 0) return incrementImpl(p)
}

protected open fun incrementBefore(p: Int) {
override fun incrementBefore(p: Int) {
currentValue[p].value = r[p].read()!!
checkPointer[p].value = 0
currentValue[p].flush()
Expand All @@ -92,10 +95,14 @@ internal abstract class CounterFailingTest :
AbstractNVMLincheckFailingTest(Recover.NRL, THREADS_NUMBER, SequentialCounter::class) {
protected abstract val counter: Counter

@Recoverable(beforeMethod = "incrementBefore", recoverMethod = "incrementRecover")
@Operation
fun increment(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.increment(threadId)
fun incrementBefore(p: Int) = counter.incrementBefore(p)
fun incrementRecover(p: Int) = counter.incrementRecover(p)

@Operation
@Recoverable
fun get(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.get(threadId)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,26 @@ private const val THREADS_NUMBER = 3
interface RWO<T> {
fun read(): T?
fun write(value: T, p: Int)
fun writeRecover(value: T, p: Int) {}
}

internal class ReadWriteObjectTest :
AbstractNVMLincheckTest(Recover.NRL, THREADS_NUMBER, SequentialReadWriteObject::class) {
private val rwo = NRLReadWriteObject<Pair<Int, Int>>(THREADS_NUMBER + 2)

@Recoverable
@Operation
fun read() = rwo.read()?.first

@Recoverable(recoverMethod = "writeRecover")
@Operation
fun write(
@Param(gen = ThreadIdGen::class) threadId: Int,
value: Int,
@Param(gen = OperationIdGen::class) operationId: Int
) = rwo.write(value to operationId, threadId)

fun writeRecover(threadId: Int, value: Int, operationId: Int) = rwo.writeRecover(value to operationId, threadId)
}

private val nullObject = Any()
Expand Down Expand Up @@ -85,10 +90,7 @@ internal open class NRLReadWriteObject<T>(threadsCount: Int, initial: T? = null)
// (state, value) for every thread
protected val state = MutableList(threadsCount) { nonVolatile(0 to null as T?) }

@Recoverable
override fun read(): T? = register.value

@Recoverable(recoverMethod = "writeRecover")
override fun write(value: T, p: Int) = writeImpl(value, p)

protected open fun writeImpl(value: T, p: Int) {
Expand All @@ -100,7 +102,7 @@ internal open class NRLReadWriteObject<T>(threadsCount: Int, initial: T? = null)
state[p].flush()
}

protected open fun writeRecover(value: T, p: Int) {
override fun writeRecover(value: T, p: Int) {
val (flag, current) = state[p].value
if (flag == 0 && current != value) return writeImpl(value, p)
else if (flag == 1 && current === register.value) return writeImpl(value, p)
Expand All @@ -114,14 +116,18 @@ internal abstract class ReadWriteObjectFailingTest :
protected abstract val rwo: RWO<Pair<Int, Int>>

@Operation
@Recoverable
fun read() = rwo.read()?.first

@Operation
@Recoverable(recoverMethod = "writeRecover")
fun write(
@Param(gen = ThreadIdGen::class) threadId: Int,
value: Int,
@Param(gen = OperationIdGen::class) operationId: Int
) = rwo.write(value to operationId, threadId)

fun writeRecover(threadId: Int, value: Int, operationId: Int) = rwo.writeRecover(value to operationId, threadId)
}

internal class SmallScenarioTest : ReadWriteObjectFailingTest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ private const val THREADS = 3
internal class RecoverableMutualExclusionWithPrimitivesTest : AbstractNVMLincheckTest(Recover.NRL, THREADS, SequentialCounter::class) {
private val counter = CounterWithLock(THREADS + 2, LockWithPrimitives(THREADS + 2))

@Recoverable(recoverMethod = "incRecover", beforeMethod = "incBefore")
@Operation
fun inc(@Param(gen = ThreadIdGen::class) threadId: Int): Int = counter.inc(threadId)
fun incRecover(threadId: Int) = counter.incRecover(threadId)
fun incBefore(threadId: Int) = counter.incBefore(threadId)

override fun testWithStressStrategy() {
println("${this::class.qualifiedName}:testWithStressStrategy test is ignored as no special atomic primitives available.")
Expand All @@ -61,7 +64,6 @@ internal class CounterWithLock(threads: Int, private val lock: DurableLock) {
private val cp = Array(threads) { nonVolatile(0) }
private val before = Array(threads) { nonVolatile(0) }

@Recoverable(recoverMethod = "incRecover", beforeMethod = "incBefore")
fun inc(threadId: Int) = incInternal(threadId)

private fun incInternal(threadId: Int): Int {
Expand Down Expand Up @@ -193,8 +195,11 @@ internal class LockWithPrimitives(threads: Int) : DurableLock {
internal class MutualExclusionFailingTest : AbstractNVMLincheckFailingTest(Recover.NRL, THREADS, SequentialCounter::class, false, DeadlockWithDumpFailure::class) {
private val counter = CounterWithLock(THREADS + 2, SimplestLockEver())

@Recoverable(recoverMethod = "incRecover", beforeMethod = "incBefore")
@Operation
fun inc(@Param(gen = ThreadIdGen::class) threadId: Int): Int = counter.inc(threadId)
fun incRecover(threadId: Int) = counter.incRecover(threadId)
fun incBefore(threadId: Int) = counter.incBefore(threadId)

override fun <O : Options<O, *>> O.customize() {
actorsBefore(0)
Expand Down
Loading

0 comments on commit 1a485d5

Please sign in to comment.