From a899a52f0c0322f46dfc9f6f27c42327e3e9e436 Mon Sep 17 00:00:00 2001 From: Ax333l Date: Sun, 10 Nov 2024 17:34:00 +0100 Subject: [PATCH] parallelism PoC --- .../app/revanced/patcher/Fingerprint.kt | 159 +++++++++---- .../kotlin/app/revanced/patcher/Patcher.kt | 211 ++++++++++-------- .../patcher/patch/BytecodePatchContext.kt | 59 +++-- .../app/revanced/patcher/patch/Patch.kt | 44 ++-- .../app/revanced/patcher/util/ClassMerger.kt | 6 +- .../revanced/patcher/util/MethodNavigator.kt | 2 +- .../revanced/patcher/util/ProxyClassList.kt | 5 +- 7 files changed, 312 insertions(+), 174 deletions(-) diff --git a/src/main/kotlin/app/revanced/patcher/Fingerprint.kt b/src/main/kotlin/app/revanced/patcher/Fingerprint.kt index b329d5f8..05c74bb1 100644 --- a/src/main/kotlin/app/revanced/patcher/Fingerprint.kt +++ b/src/main/kotlin/app/revanced/patcher/Fingerprint.kt @@ -13,6 +13,93 @@ import com.android.tools.smali.dexlib2.iface.Method import com.android.tools.smali.dexlib2.iface.instruction.ReferenceInstruction import com.android.tools.smali.dexlib2.iface.reference.StringReference import com.android.tools.smali.dexlib2.util.MethodUtil +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.flatMapMerge +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.job +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.selects.select +import kotlinx.coroutines.withContext + +/* +suspend inline fun Iterable.concurrentFirstNotNullOfOrNull(crossinline transform: (T) -> R?) = + asFlow().flatMapMerge { value -> + flow { transform(value)?.let { emit(it) } } + }.firstOrNull() +*/ +/* +suspend inline fun Iterable.concurrentFind(crossinline predicate: (T) -> Boolean) = + asFlow().flatMapMerge { value -> flow { if (predicate(value)) emit(value) } }.firstOrNull() +*/ +internal fun List<*>.chunks(count: Int): List> { + if (size <= count) return listOf(0 to lastIndex) + + val chunkSize = size / count + val indices = MutableList(count) { (it*chunkSize) + 1 to (it + 1)*chunkSize } + indices[0] = 0 to indices[0].second + indices[indices.lastIndex] = indices[indices.lastIndex].first to lastIndex + return indices +} +internal suspend inline fun List.concurrentFind(crossinline predicate: (T) -> Boolean): T? = coroutineScope { + val cpus = Runtime.getRuntime().availableProcessors() + val completableDeferred = CompletableDeferred(parent = coroutineContext.job) + val jobs = chunks(cpus).map { (start, end) -> + launch(Dispatchers.Default) { + var i = start + while (i <= end) { + val element = this@concurrentFind[i] + if (predicate(element)) { + completableDeferred.complete(element) + return@launch + } + i++ + } + } + } + val notFoundJob = launch { + jobs.joinAll() + completableDeferred.complete(null) + } + + val result = completableDeferred.await() + jobs.forEach(Job::cancel) + notFoundJob.cancel() + result +} + +internal suspend inline fun List.concurrentFirstNotNullOfOrNull(crossinline transform: (T) -> R?): R? = coroutineScope { + val cpus = Runtime.getRuntime().availableProcessors() + val completableDeferred = CompletableDeferred(parent = coroutineContext.job) + val jobs = chunks(cpus).map { (start, end) -> + launch(Dispatchers.Default) { + var i = start + while (i <= end) { + val element = this@concurrentFirstNotNullOfOrNull[i] + transform(element)?.let { value -> + completableDeferred.complete(value) + return@launch + } + i++ + } + } + } + val notFoundJob = launch { + jobs.joinAll() + completableDeferred.complete(null) + } + + val result = completableDeferred.await() + jobs.forEach(Job::cancel) + notFoundJob.cancel() + result +} /** * A fingerprint for a method. A fingerprint is a partial description of a method. @@ -51,9 +138,11 @@ class Fingerprint internal constructor( /** * The match for this [Fingerprint]. Null if unmatched. */ + /* context(BytecodePatchContext) private val matchOrNull: Match? get() = matchOrNull() + */ /** * Match using [BytecodePatchContext.lookupMaps]. @@ -69,10 +158,10 @@ class Fingerprint internal constructor( * @return The [Match] if a match was found or if the fingerprint is already matched to a method, null otherwise. */ context(BytecodePatchContext) - internal fun matchOrNull(): Match? { + internal suspend fun matchOrNull(): Match? { if (_matchOrNull != null) return _matchOrNull - var match = strings?.mapNotNull { + val match = strings?.mapNotNull { lookupMaps.methodsByStrings[it] }?.minByOrNull { it.size }?.let { methodClasses -> methodClasses.forEach { (classDef, method) -> @@ -84,14 +173,9 @@ class Fingerprint internal constructor( } if (match != null) return match - synchronized(classes) { - classes.forEach { classDef -> - match = matchOrNull(classDef) - if (match != null) return match - } + return withContext(Dispatchers.Default) { + classes.concurrentFirstNotNullOfOrNull { matchOrNull(it) } } - - return null } /** @@ -122,7 +206,7 @@ class Fingerprint internal constructor( * @return The [Match] if a match was found or if the fingerprint is already matched to a method, null otherwise. */ context(BytecodePatchContext) - fun matchOrNull( + suspend fun matchOrNull( method: Method, ) = matchOrNull(method, classBy { method.definingClass == it.type }!!.immutableClass) @@ -184,7 +268,8 @@ class Fingerprint internal constructor( return@forEachIndexed } - val string = ((instruction as ReferenceInstruction).reference as StringReference).string + val string = + ((instruction as ReferenceInstruction).reference as StringReference).string val index = stringsList.indexOfFirst(string::contains) if (index == -1) return@forEachIndexed @@ -261,8 +346,7 @@ class Fingerprint internal constructor( * @throws PatchException If the [Fingerprint] has not been matched. */ context(BytecodePatchContext) - private val match - get() = matchOrNull ?: throw exception + private suspend fun match() = matchOrNull() ?: throw exception /** * Match using a [ClassDef]. @@ -285,7 +369,7 @@ class Fingerprint internal constructor( * @throws PatchException If the fingerprint has not been matched. */ context(BytecodePatchContext) - fun match( + suspend fun match( method: Method, ) = matchOrNull(method) ?: throw exception @@ -307,15 +391,13 @@ class Fingerprint internal constructor( * The class the matching method is a member of. */ context(BytecodePatchContext) - val originalClassDefOrNull - get() = matchOrNull?.originalClassDef + suspend fun originalClassDefOrNull() = matchOrNull()?.originalClassDef /** * The matching method. */ context(BytecodePatchContext) - val originalMethodOrNull - get() = matchOrNull?.originalMethod + suspend fun originalMethodOrNull() = matchOrNull()?.originalMethod /** * The mutable version of [originalClassDefOrNull]. @@ -324,8 +406,7 @@ class Fingerprint internal constructor( * Use [originalClassDefOrNull] if mutable access is not required. */ context(BytecodePatchContext) - val classDefOrNull - get() = matchOrNull?.classDef + suspend fun classDefOrNull() = matchOrNull()?.classDef /** * The mutable version of [originalMethodOrNull]. @@ -334,22 +415,19 @@ class Fingerprint internal constructor( * Use [originalMethodOrNull] if mutable access is not required. */ context(BytecodePatchContext) - val methodOrNull - get() = matchOrNull?.method + suspend fun methodOrNull() = matchOrNull()?.method /** * The match for the opcode pattern. */ context(BytecodePatchContext) - val patternMatchOrNull - get() = matchOrNull?.patternMatch + suspend fun patternMatchOrNull() = matchOrNull()?.patternMatch /** * The matches for the strings. */ context(BytecodePatchContext) - val stringMatchesOrNull - get() = matchOrNull?.stringMatches + suspend fun stringMatchesOrNull() = matchOrNull()?.stringMatches /** * The class the matching method is a member of. @@ -357,8 +435,7 @@ class Fingerprint internal constructor( * @throws PatchException If the fingerprint has not been matched. */ context(BytecodePatchContext) - val originalClassDef - get() = match.originalClassDef + suspend fun originalClassDef() = match().originalClassDef /** * The matching method. @@ -366,8 +443,7 @@ class Fingerprint internal constructor( * @throws PatchException If the fingerprint has not been matched. */ context(BytecodePatchContext) - val originalMethod - get() = match.originalMethod + suspend fun originalMethod() = match().originalMethod /** * The mutable version of [originalClassDef]. @@ -378,8 +454,7 @@ class Fingerprint internal constructor( * @throws PatchException If the fingerprint has not been matched. */ context(BytecodePatchContext) - val classDef - get() = match.classDef + suspend fun classDef() = match().classDef /** * The mutable version of [originalMethod]. @@ -390,8 +465,7 @@ class Fingerprint internal constructor( * @throws PatchException If the fingerprint has not been matched. */ context(BytecodePatchContext) - val method - get() = match.method + suspend fun method() = match().method /** * The match for the opcode pattern. @@ -399,8 +473,7 @@ class Fingerprint internal constructor( * @throws PatchException If the fingerprint has not been matched. */ context(BytecodePatchContext) - val patternMatch - get() = match.patternMatch + suspend fun patternMatch() = match().patternMatch /** * The matches for the strings. @@ -408,8 +481,7 @@ class Fingerprint internal constructor( * @throws PatchException If the fingerprint has not been matched. */ context(BytecodePatchContext) - val stringMatches - get() = match.stringMatches + suspend fun stringMatches() = match().stringMatches } /** @@ -433,7 +505,7 @@ class Match internal constructor( * Accessing this property allocates a [ClassProxy]. * Use [originalClassDef] if mutable access is not required. */ - val classDef by lazy { proxy(originalClassDef).mutableClass } + val classDef by lazy { syncProxy(originalClassDef).mutableClass } /** * The mutable version of [originalMethod]. @@ -441,7 +513,14 @@ class Match internal constructor( * Accessing this property allocates a [ClassProxy]. * Use [originalMethod] if mutable access is not required. */ - val method by lazy { classDef.methods.first { MethodUtil.methodSignaturesMatch(it, originalMethod) } } + val method by lazy { + classDef.methods.first { + MethodUtil.methodSignaturesMatch( + it, + originalMethod + ) + } + } /** * A match for an opcode pattern. diff --git a/src/main/kotlin/app/revanced/patcher/Patcher.kt b/src/main/kotlin/app/revanced/patcher/Patcher.kt index 1cfaa6d5..1a16904e 100644 --- a/src/main/kotlin/app/revanced/patcher/Patcher.kt +++ b/src/main/kotlin/app/revanced/patcher/Patcher.kt @@ -3,9 +3,12 @@ package app.revanced.patcher import app.revanced.patcher.patch.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import java.io.Closeable import java.util.concurrent.ConcurrentHashMap import java.util.logging.Logger +import kotlin.time.measureTime /** * A Patcher. @@ -58,123 +61,149 @@ class Patcher(private val config: PatcherConfig) : Closeable { * * @return A flow of [PatchResult]s. */ + @OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class) operator fun invoke() = channelFlow { - // Prevent decoding the app manifest twice if it is not needed. - if (config.resourceMode != ResourcePatchContext.ResourceMode.NONE) { - context.resourceContext.decodeResources(config.resourceMode) - } + coroutineScope { + launch(Dispatchers.Default) { + // Prevent decoding the app manifest twice if it is not needed. + if (config.resourceMode != ResourcePatchContext.ResourceMode.NONE) { + context.resourceContext.decodeResources(config.resourceMode) + } + } - logger.info("Initializing lookup maps") + launch(Dispatchers.Default) { + logger.info("Initializing lookup maps") - // Accessing the lazy lookup maps to initialize them. - context.bytecodeContext.lookupMaps + // Accessing the lazy lookup maps to initialize them. + context.bytecodeContext.lookupMaps + } + } logger.info("Executing patches") + // Dispatcher.Default.limitedParallelism(1) + newSingleThreadContext("Patcher").use { dispatcher -> - val executedPatches = ConcurrentHashMap, Deferred>() + val executedPatches = HashMap, Deferred>() + val bytecodeLock = Mutex() - suspend operator fun Patch<*>.invoke(): Deferred { - val patch = this + suspend fun Patch<*>.runBlock(block: suspend () -> Unit) { + if (this is BytecodePatch) bytecodeLock.withLock { block() } else withContext( + Dispatchers.IO + ) { block() } + } - // If the patch was executed before or failed, return it's the result. - executedPatches[patch]?.let { deferredPatchResult -> - val patchResult = deferredPatchResult.await() + suspend operator fun Patch<*>.invoke(): Deferred { + val patch = this - patchResult.exception ?: return deferredPatchResult + // If the patch was executed before or failed, return it's the result. + executedPatches[patch]?.let { deferredPatchResult -> + val patchResult = deferredPatchResult.await() - return CompletableDeferred(PatchResult(patch, PatchException("The patch '$patch' failed previously"))) - } + patchResult.exception ?: return deferredPatchResult - return async(Dispatchers.IO) { - // Recursively execute all dependency patches. - val dependenciesResult = coroutineScope { - val dependenciesJobs = dependencies.map { dependency -> - async(Dispatchers.IO) { - dependency().await().exception?.let { exception -> - PatchResult( - patch, - PatchException( - "The patch \"$patch\" depends on \"$dependency\", which raised an exception:\n" + - exception.stackTraceToString(), - ), - ) + return CompletableDeferred( + PatchResult( + patch, + PatchException("The patch '$patch' failed previously") + ) + ) + } + + return async(dispatcher) { + // Recursively execute all dependency patches. + val dependenciesResult = coroutineScope { + val dependenciesJobs = dependencies.map { dependency -> + async(dispatcher) { + dependency().await().exception?.let { exception -> + PatchResult( + patch, + PatchException( + "The patch \"$patch\" depends on \"$dependency\", which raised an exception:\n" + + exception.stackTraceToString(), + ), + ) + } } } - } - dependenciesJobs.awaitAll().firstOrNull { result -> result != null }?.let { - dependenciesJobs.forEach(Deferred<*>::cancel) + dependenciesJobs.awaitAll().firstOrNull { result -> result != null }?.let { + dependenciesJobs.forEach(Deferred<*>::cancel) - return@coroutineScope it + return@coroutineScope it + } } - } - - if (dependenciesResult != null) { - return@async dependenciesResult - } - - // Execute the patch. - try { - execute(context) - PatchResult(patch) - } catch (exception: PatchException) { - PatchResult(patch, exception) - } catch (exception: Exception) { - PatchResult(patch, PatchException(exception)) - } - }.also { executedPatches[patch] = it } - } + if (dependenciesResult != null) { + return@async dependenciesResult + } - coroutineScope { - context.executablePatches.sortedBy { it.name }.map { patch -> - launch(Dispatchers.IO) { - val patchResult = patch().await() + // Execute the patch. + try { + runBlock { execute(context) } - // If an exception occurred or the patch has no finalize block, emit the result. - if (patchResult.exception != null || patch.finalizeBlock == null) { - send(patchResult) + PatchResult(patch) + } catch (exception: PatchException) { + PatchResult(patch, exception) + } catch (exception: Exception) { + PatchResult(patch, PatchException(exception)) } - } - }.joinAll() - } + }.also { executedPatches[patch] = it } + } - val succeededPatchesWithFinalizeBlock = executedPatches.values.map { it.await() }.filter { - it.exception == null && it.patch.finalizeBlock != null - } + val time = measureTime { + coroutineScope { + context.executablePatches.sortedBy { it.name }.map { patch -> + launch(dispatcher) { + val patchResult = patch().await() - coroutineScope { - succeededPatchesWithFinalizeBlock.asReversed().map { executionResult -> - launch(Dispatchers.IO) { - val patch = executionResult.patch - - val result = - try { - patch.finalize(context) - - executionResult - } catch (exception: PatchException) { - PatchResult(patch, exception) - } catch (exception: Exception) { - PatchResult(patch, PatchException(exception)) + // If an exception occurred or the patch has no finalize block, emit the result. + if (patchResult.exception != null || patch.finalizeBlock == null) { + send(patchResult) + } } + }.joinAll() + } - if (result.exception != null) { - send( - PatchResult( - patch, - PatchException( - "The patch \"$patch\" raised an exception during finalization:\n" + - result.exception.stackTraceToString(), - result.exception, - ), - ), - ) - } else if (patch in context.executablePatches) { - send(result) + val succeededPatchesWithFinalizeBlock = + executedPatches.values.map { it.await() }.filter { + it.exception == null && it.patch.finalizeBlock != null } + + coroutineScope { + succeededPatchesWithFinalizeBlock.asReversed().map { executionResult -> + launch(dispatcher) { + val patch = executionResult.patch + + val result = + try { + patch.runBlock { patch.finalize(context) } + + executionResult + } catch (exception: PatchException) { + PatchResult(patch, exception) + } catch (exception: Exception) { + PatchResult(patch, PatchException(exception)) + } + + if (result.exception != null) { + send( + PatchResult( + patch, + PatchException( + "The patch \"$patch\" raised an exception during finalization:\n" + + result.exception.stackTraceToString(), + result.exception, + ), + ), + ) + } else if (patch in context.executablePatches) { + send(result) + } + } + }.joinAll() } - }.joinAll() + } + logger.info("Patching completed in $time") } } diff --git a/src/main/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt b/src/main/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt index b243153c..1c4fff18 100644 --- a/src/main/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt +++ b/src/main/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt @@ -3,6 +3,8 @@ package app.revanced.patcher.patch import app.revanced.patcher.InternalApi import app.revanced.patcher.PatcherConfig import app.revanced.patcher.PatcherResult +import app.revanced.patcher.concurrentFind +import app.revanced.patcher.concurrentFirstNotNullOfOrNull import app.revanced.patcher.extensions.InstructionExtensions.instructionsOrNull import app.revanced.patcher.util.ClassMerger.merge import app.revanced.patcher.util.MethodNavigator @@ -16,6 +18,8 @@ import com.android.tools.smali.dexlib2.iface.Method import com.android.tools.smali.dexlib2.iface.instruction.ReferenceInstruction import com.android.tools.smali.dexlib2.iface.reference.MethodReference import com.android.tools.smali.dexlib2.iface.reference.StringReference +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import lanchon.multidexlib2.BasicDexFileNamer import lanchon.multidexlib2.DexIO import lanchon.multidexlib2.MultiDexIO @@ -65,9 +69,15 @@ class BytecodePatchContext internal constructor(private val config: PatcherConfi * * @param bytecodePatch The [BytecodePatch] to merge the extension of. */ - internal fun mergeExtension(bytecodePatch: BytecodePatch) { + internal suspend fun mergeExtension(bytecodePatch: BytecodePatch) { bytecodePatch.extensionInputStream?.get()?.use { extensionStream -> - RawDexIO.readRawDexFile(extensionStream, 0, null).classes.forEach { classDef -> + withContext(Dispatchers.IO) { + RawDexIO.readRawDexFile( + extensionStream, + 0, + null + ) + }.classes.forEach { classDef -> val existingClass = lookupMaps.classesByType[classDef.type] ?: run { logger.fine { "Adding class \"$classDef\"" } @@ -98,18 +108,21 @@ class BytecodePatchContext internal constructor(private val config: PatcherConfi * @param predicate A predicate to match the class. * @return A proxy for the first class that matches the predicate. */ - fun classBy(predicate: (ClassDef) -> Boolean): ClassProxy? { - val proxy = synchronized(classes.proxyPool) { classes.proxyPool.find { predicate(it.immutableClass) } } - if (proxy != null) return proxy + suspend fun classBy(predicate: (ClassDef) -> Boolean): ClassProxy? = + withContext(Dispatchers.Default) { + // val proxy = synchronized(classes.proxyPool) { classes.proxyPool.find { predicate(it.immutableClass) } } + val proxy = classes.proxyPool.concurrentFind { predicate(it.immutableClass) } + if (proxy != null) return@withContext proxy + + val classDef = classes.concurrentFind { predicate(it) } + if (classDef != null) { + // return proxy(classDef) + return@withContext ClassProxy(classDef).also { classes.proxyPool.add(it) } + } - val classDef = synchronized(classes) { classes.find(predicate) } - if (classDef != null) { - return proxy(classDef) + return@withContext null } - return null - } - /** * Proxy the class to allow mutation. * @@ -117,9 +130,25 @@ class BytecodePatchContext internal constructor(private val config: PatcherConfi * * @return A proxy for the class. */ - fun proxy(classDef: ClassDef) = synchronized(classes.proxyPool) { + suspend fun proxy(classDef: ClassDef) = withContext(Dispatchers.Default) { + classes.proxyPool.concurrentFind { it.immutableClass.type == classDef.type } + ?: ClassProxy(classDef).also { classes.proxyPool.add(it) } + } + + internal fun syncProxy(classDef: ClassDef) = classes.proxyPool.find { it.immutableClass.type == classDef.type } ?: ClassProxy(classDef).also { classes.proxyPool.add(it) } + + internal fun classBySync(predicate: (ClassDef) -> Boolean): ClassProxy? { + val proxy = classes.proxyPool.find { predicate(it.immutableClass) } + if (proxy != null) return proxy + + val classDef = classes.find(predicate) + if (classDef != null) { + return ClassProxy(classDef).also { classes.proxyPool.add(it) } + } + + return null } /** @@ -155,7 +184,8 @@ class BytecodePatchContext internal constructor(private val config: PatcherConfi BasicDexFileNamer(), object : DexFile { override fun getClasses() = - this@BytecodePatchContext.classes.also(ProxyClassList::replaceClasses).toSet() + this@BytecodePatchContext.classes.also(ProxyClassList::replaceClasses) + .toSet() override fun getOpcodes() = this@BytecodePatchContext.opcodes }, @@ -197,7 +227,8 @@ class BytecodePatchContext internal constructor(private val config: PatcherConfi return@instructions } - val string = ((instruction as ReferenceInstruction).reference as StringReference).string + val string = + ((instruction as ReferenceInstruction).reference as StringReference).string methodsByStrings[string] = methodClassPair } diff --git a/src/main/kotlin/app/revanced/patcher/patch/Patch.kt b/src/main/kotlin/app/revanced/patcher/patch/Patch.kt index 8f0dc838..5acd82a5 100644 --- a/src/main/kotlin/app/revanced/patcher/patch/Patch.kt +++ b/src/main/kotlin/app/revanced/patcher/patch/Patch.kt @@ -45,10 +45,10 @@ sealed class Patch>( val dependencies: Set>, val compatiblePackages: Set?, options: Set>, - private val executeBlock: (C) -> Unit, + private val executeBlock: suspend (C) -> Unit, // Must be internal and nullable, so that Patcher.invoke can check, // if a patch has a finalizing block in order to not emit it twice. - internal var finalizeBlock: ((C) -> Unit)?, + internal var finalizeBlock: (suspend (C) -> Unit)?, ) { /** * The options of the patch. @@ -61,14 +61,14 @@ sealed class Patch>( * * @param context The [PatcherContext] to get the [PatchContext] from to execute the patch with. */ - internal abstract fun execute(context: PatcherContext) + internal abstract suspend fun execute(context: PatcherContext) /** * Calls the execution block of the patch. * * @param context The [PatchContext] to execute the patch with. */ - fun execute(context: C) = executeBlock(context) + suspend fun execute(context: C) = executeBlock(context) /** * Calls the finalizing block of the patch. @@ -76,14 +76,14 @@ sealed class Patch>( * * @param context The [PatcherContext] to get the [PatchContext] from to finalize the patch with. */ - internal abstract fun finalize(context: PatcherContext) + internal abstract suspend fun finalize(context: PatcherContext) /** * Calls the finalizing block of the patch. * * @param context The [PatchContext] to finalize the patch with. */ - fun finalize(context: C) { + suspend fun finalize(context: C) { finalizeBlock?.invoke(context) } @@ -142,8 +142,8 @@ class BytecodePatch internal constructor( dependencies: Set>, options: Set>, val extensionInputStream: Supplier?, - executeBlock: (BytecodePatchContext) -> Unit, - finalizeBlock: ((BytecodePatchContext) -> Unit)?, + executeBlock: suspend (BytecodePatchContext) -> Unit, + finalizeBlock: (suspend (BytecodePatchContext) -> Unit)?, ) : Patch( name, description, @@ -154,12 +154,12 @@ class BytecodePatch internal constructor( executeBlock, finalizeBlock, ) { - override fun execute(context: PatcherContext) = with(context.bytecodeContext) { + override suspend fun execute(context: PatcherContext) = with(context.bytecodeContext) { mergeExtension(this@BytecodePatch) execute(this) } - override fun finalize(context: PatcherContext) = finalize(context.bytecodeContext) + override suspend fun finalize(context: PatcherContext) = finalize(context.bytecodeContext) override fun toString() = name ?: "BytecodePatch" } @@ -188,8 +188,8 @@ class RawResourcePatch internal constructor( compatiblePackages: Set?, dependencies: Set>, options: Set>, - executeBlock: (ResourcePatchContext) -> Unit, - finalizeBlock: ((ResourcePatchContext) -> Unit)?, + executeBlock: suspend (ResourcePatchContext) -> Unit, + finalizeBlock: (suspend (ResourcePatchContext) -> Unit)?, ) : Patch( name, description, @@ -200,9 +200,9 @@ class RawResourcePatch internal constructor( executeBlock, finalizeBlock, ) { - override fun execute(context: PatcherContext) = execute(context.resourceContext) + override suspend fun execute(context: PatcherContext) = execute(context.resourceContext) - override fun finalize(context: PatcherContext) = finalize(context.resourceContext) + override suspend fun finalize(context: PatcherContext) = finalize(context.resourceContext) override fun toString() = name ?: "RawResourcePatch" } @@ -231,8 +231,8 @@ class ResourcePatch internal constructor( compatiblePackages: Set?, dependencies: Set>, options: Set>, - executeBlock: (ResourcePatchContext) -> Unit, - finalizeBlock: ((ResourcePatchContext) -> Unit)?, + executeBlock: suspend (ResourcePatchContext) -> Unit, + finalizeBlock: (suspend (ResourcePatchContext) -> Unit)?, ) : Patch( name, description, @@ -243,9 +243,9 @@ class ResourcePatch internal constructor( executeBlock, finalizeBlock, ) { - override fun execute(context: PatcherContext) = execute(context.resourceContext) + override suspend fun execute(context: PatcherContext) = execute(context.resourceContext) - override fun finalize(context: PatcherContext) = finalize(context.resourceContext) + override suspend fun finalize(context: PatcherContext) = finalize(context.resourceContext) override fun toString() = name ?: "ResourcePatch" } @@ -277,8 +277,8 @@ sealed class PatchBuilder>( protected var dependencies = mutableSetOf>() protected val options = mutableSetOf>() - protected var executionBlock: ((C) -> Unit) = { } - protected var finalizeBlock: ((C) -> Unit)? = null + protected var executionBlock: (suspend (C) -> Unit) = { } + protected var finalizeBlock: (suspend (C) -> Unit)? = null /** * Add an option to the patch. @@ -337,7 +337,7 @@ sealed class PatchBuilder>( * * @param block The execution block of the patch. */ - fun execute(block: C.() -> Unit) { + fun execute(block: suspend C.() -> Unit) { executionBlock = block } @@ -346,7 +346,7 @@ sealed class PatchBuilder>( * * @param block The finalizing block of the patch. */ - fun finalize(block: C.() -> Unit) { + fun finalize(block: suspend C.() -> Unit) { finalizeBlock = block } diff --git a/src/main/kotlin/app/revanced/patcher/util/ClassMerger.kt b/src/main/kotlin/app/revanced/patcher/util/ClassMerger.kt index d9a3a218..997f974e 100644 --- a/src/main/kotlin/app/revanced/patcher/util/ClassMerger.kt +++ b/src/main/kotlin/app/revanced/patcher/util/ClassMerger.kt @@ -33,7 +33,7 @@ internal object ClassMerger { * @param context The context to traverse the class hierarchy in. * @return The merged class or the original class if no merge was needed. */ - fun ClassDef.merge( + suspend fun ClassDef.merge( otherClass: ClassDef, context: BytecodePatchContext, ) = this @@ -92,7 +92,7 @@ internal object ClassMerger { * @param reference The class to check the [AccessFlags] of. * @param context The context to traverse the class hierarchy in. */ - private fun ClassDef.publicize( + private suspend fun ClassDef.publicize( reference: ClassDef, context: BytecodePatchContext, ) = if (reference.accessFlags.isPublic() && !accessFlags.isPublic()) { @@ -174,7 +174,7 @@ internal object ClassMerger { * @param targetClass the class to start traversing the class hierarchy from * @param callback function that is called for every class in the hierarchy */ - fun BytecodePatchContext.traverseClassHierarchy( + suspend fun BytecodePatchContext.traverseClassHierarchy( targetClass: MutableClass, callback: MutableClass.() -> Unit, ) { diff --git a/src/main/kotlin/app/revanced/patcher/util/MethodNavigator.kt b/src/main/kotlin/app/revanced/patcher/util/MethodNavigator.kt index d894e9e7..4c96946c 100644 --- a/src/main/kotlin/app/revanced/patcher/util/MethodNavigator.kt +++ b/src/main/kotlin/app/revanced/patcher/util/MethodNavigator.kt @@ -80,7 +80,7 @@ class MethodNavigator internal constructor( * * @return The last navigated method mutably. */ - fun stop() = classBy(matchesCurrentMethodReferenceDefiningClass)!!.mutableClass.firstMethodBySignature + fun stop() = classBySync(matchesCurrentMethodReferenceDefiningClass)!!.mutableClass.firstMethodBySignature as MutableMethod /** diff --git a/src/main/kotlin/app/revanced/patcher/util/ProxyClassList.kt b/src/main/kotlin/app/revanced/patcher/util/ProxyClassList.kt index e89da661..ac19ce85 100644 --- a/src/main/kotlin/app/revanced/patcher/util/ProxyClassList.kt +++ b/src/main/kotlin/app/revanced/patcher/util/ProxyClassList.kt @@ -2,7 +2,6 @@ package app.revanced.patcher.util import app.revanced.patcher.util.proxy.ClassProxy import com.android.tools.smali.dexlib2.iface.ClassDef -import java.util.* /** * A list of classes and proxies. @@ -11,8 +10,8 @@ import java.util.* */ class ProxyClassList internal constructor( classes: MutableList, -) : MutableList by Collections.synchronizedList(classes) { - internal val proxyPool = Collections.synchronizedList(mutableListOf()) +) : MutableList by classes { + internal val proxyPool = mutableListOf() /** * Replace all classes with their mutated versions.