Skip to content

Commit

Permalink
improvement: cancel current request in batched functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Jul 18, 2023
1 parent f9bd535 commit ce3c761
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package scala.meta.internal.metals

import java.util.concurrent.CancellationException
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.util.control.NonFatal

import scala.meta.internal.async.ConcurrentQueue
Expand All @@ -22,10 +26,12 @@ final class BatchedFunction[A, B](
fn: Seq[A] => CancelableFuture[B],
functionId: String,
shouldLogQueue: Boolean = false,
default: Option[B] = None,
)(implicit ec: ExecutionContext)
extends (Seq[A] => Future[B])
with Function2[Seq[A], () => Unit, Future[B]]
with Pauseable {
val rand = new scala.util.Random

/**
* Call the function with the given arguments.
Expand Down Expand Up @@ -75,14 +81,33 @@ final class BatchedFunction[A, B](
}

def cancelAll(): Unit = {
queue.clear()
unlock()
val requests = ConcurrentQueue.pollAll(queue)
requests.foreach(_.result.complete(defaultResult))
cancelCurrent()
}

def cancelCurrent(): Unit = {
def withRetry(retry: Int): Unit =
currentCancelable.get() match {
case Some(cancelable) => cancelable.cancel()
case None =>
if (lock.get() != 0) {
Thread.sleep(100)
withRetry(retry - 1)
}
}
pause()
withRetry(1)
unpause()
}

def currentFuture(): Future[B] = {
current.get().future
}

private val currentCancelable =
new AtomicReference[Option[Cancelable]](None)

private val current = new AtomicReference(
CancelableFuture[B](
Future.failed(new NoSuchElementException("BatchedFunction")),
Expand All @@ -97,22 +122,25 @@ final class BatchedFunction[A, B](
callback: () => Unit,
)

private val lock = new AtomicBoolean()
private val lock = new AtomicInteger(0)

private def unlock(): Unit = {
lock.set(false)
lock.set(0)
currentCancelable.set(None)
if (!queue.isEmpty) {
runAcquire()
}
}
private def runAcquire(): Unit = {
if (!isPaused.get() && lock.compareAndSet(false, true)) {
runRelease()
val id = rand.nextInt()
if (!isPaused.get() && lock.compareAndSet(0, id)) {
runRelease(id: Int)
} else {
// Do nothing, the submitted arguments will be handled
// by a separate request.
}
}
private def runRelease(): Unit = {
private def runRelease(id: Int): Unit = {
// Pre-condition: lock is acquired.
// Pos-condition:
// - lock is released
Expand All @@ -125,40 +153,53 @@ final class BatchedFunction[A, B](
val args = requests.flatMap(_.arguments)
val callbacks = requests.map(_.callback)
val result = fn(args)
currentCancelable.set(Some(Cancelable { () =>
if (lock.compareAndSet(id, 0)) {
result.cancel()
requests.foreach(_.result.complete(defaultResult))
}
}))
this.current.set(result)
val resultF = for {
result <- result.future
_ <- Future {
callbacks.foreach(cb => cb())
}
_ <- Future { callbacks.foreach(cb => cb()) }
} yield result
resultF.onComplete { response =>
unlock()
requests.foreach(_.result.complete(response))
if (lock.compareAndSet(id, 0)) {
unlock()
requests.foreach(_.result.complete(response))
}
}
} else {
unlock()
}
} catch {
case NonFatal(e) =>
unlock()
requests.foreach(_.result.failure(e))
if (lock.compareAndSet(id, 0)) {
unlock()
requests.foreach(_.result.failure(e))
}
scribe.error(s"Unexpected error releasing buffered job", e)
}
}

def defaultResult: Try[B] =
default.map(Success(_)).getOrElse(Failure(new CancellationException))
}

object BatchedFunction {
def fromFuture[A, B](
fn: Seq[A] => Future[B],
functionId: String,
shouldLogQueue: Boolean = false,
default: Option[B] = None,
)(implicit
ec: ExecutionContext
): BatchedFunction[A, B] =
new BatchedFunction(
fn.andThen(CancelableFuture(_)),
functionId,
shouldLogQueue,
default,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class BuildServerConnection private (

private val ongoingRequests =
new MutableCancelable().addAll(initialConnection.cancelables)
private val ongoingCompilations = new MutableCancelable()

def version: String = _version.get()

Expand Down Expand Up @@ -137,7 +136,6 @@ class BuildServerConnection private (
def compile(params: CompileParams): CompletableFuture[CompileResult] = {
register(
server => server.buildTargetCompile(params),
isCompile = true,
onFail = Some(
(
new CompileResult(StatusCode.CANCELLED),
Expand Down Expand Up @@ -282,14 +280,9 @@ class BuildServerConnection private (
override def cancel(): Unit = {
if (cancelled.compareAndSet(false, true)) {
ongoingRequests.cancel()
ongoingCompilations.cancel()
}
}

def cancelCompilations(): Unit = {
ongoingCompilations.cancel()
}

private def askUser(): Future[BuildServerConnection.LauncherConnection] = {
if (config.askToReconnect) {
if (!reconnectNotification.isDismissed) {
Expand Down Expand Up @@ -337,7 +330,6 @@ class BuildServerConnection private (
private def register[T: ClassTag](
action: MetalsBuildServer => CompletableFuture[T],
onFail: => Option[(T, String)] = None,
isCompile: Boolean = false,
): CompletableFuture[T] = {

def runWithCanceling(
Expand All @@ -347,14 +339,12 @@ class BuildServerConnection private (
val cancelable = Cancelable { () =>
Try(resultFuture.cancel(true))
}
if (isCompile) ongoingCompilations.add(cancelable)
else ongoingRequests.add(cancelable)
ongoingRequests.add(cancelable)

val result = resultFuture.asScala

result.onComplete { _ =>
if (isCompile) ongoingCompilations.remove(cancelable)
else ongoingRequests.remove(cancelable)
ongoingRequests.remove(cancelable)
}
result
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ final class Compilations(
new BatchedFunction[
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](compile, "compileBatch", shouldLogQueue = true)
](compile, "compileBatch", shouldLogQueue = true, Some(Map.empty))
private val cascadeBatch =
new BatchedFunction[
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](compile, "cascadeBatch", shouldLogQueue = true)
](compile, "cascadeBatch", shouldLogQueue = true, Some(Map.empty))
def pauseables: List[Pauseable] = List(compileBatch, cascadeBatch)

private val isCompiling = TrieMap.empty[b.BuildTargetIdentifier, Boolean]
Expand Down Expand Up @@ -115,15 +115,6 @@ final class Compilations(
def cancel(): Unit = {
cascadeBatch.cancelAll()
compileBatch.cancelAll()
buildTargets.all
.flatMap { target =>
buildTargets.buildServerOf(target.getId())
}
.distinct
.foreach { conn =>
conn.cancelCompilations()
}

}

def recompileAll(): Future[Unit] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2119,6 +2119,8 @@ class MetalsLspService(
}

def disconnectOldBuildServer(): Future[Unit] = {
compilations.cancel()
buildTargetClasses.cancel()
diagnostics.reset()
bspSession.foreach(connection =>
scribe.info(s"Disconnecting from ${connection.main.name} session...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ final class BuildTargetClasses(
: TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] =
TrieMap.empty[b.BuildTargetIdentifier, b.JvmEnvironmentItem]
val rebuildIndex: BatchedFunction[b.BuildTargetIdentifier, Unit] =
BatchedFunction.fromFuture(fetchClasses, "buildTargetClasses")
BatchedFunction.fromFuture(
fetchClasses,
"buildTargetClasses",
default = Some(()),
)

def classesOf(target: b.BuildTargetIdentifier): Classes = {
index.getOrElse(target, new Classes)
Expand Down Expand Up @@ -171,6 +175,10 @@ final class BuildTargetClasses(
val name = NameTransformer.decode(names.last)
descriptors.map(descriptor => Symbols.Global(prefix, descriptor(name)))
}

def cancel(): Unit = {
rebuildIndex.cancelAll()
}
}

sealed abstract class TestFramework(val canResolveChildren: Boolean)
Expand Down
54 changes: 49 additions & 5 deletions tests/unit/src/test/scala/tests/BatchedFunctionSuite.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package tests

import java.util.concurrent.Executors

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.util.Success

import scala.meta.internal.metals.BatchedFunction
import scala.meta.internal.metals.Cancelable
import scala.meta.internal.metals.CancelableFuture

class BatchedFunctionSuite extends BaseSuite {
test("batch") {
Expand Down Expand Up @@ -95,11 +99,51 @@ class BatchedFunctionSuite extends BaseSuite {

mkString.unpause()

assertDiffEqual(paused.value, None)
assertDiffEqual(paused2.value, None)

val unpaused2 = mkString(List("a", "b"))
assertDiffEqual(unpaused2.value, Some(Success("ab")))
for {
_ <- paused.failed
_ <- paused2.failed
res <- mkString(List("a", "b"))
_ = assertEquals(res, "ab")
} yield ()
}

test("cancel2") {
val executorService = Executors.newFixedThreadPool(10)
val ec2 = ExecutionContext.fromExecutor(executorService)
var i = 1
val stuckExample: BatchedFunction[String, String] =
new BatchedFunction(
(seq: Seq[String]) => {
seq.toList match {
case "loop" :: Nil =>
val future = Future.apply {
while (i == 1) {
Thread.sleep(1)
}
"loop-result"
}(ec2)
CancelableFuture[String](future, Cancelable { () => i = 2 })
case _ =>
CancelableFuture[String](
Future.successful("result"),
Cancelable.empty,
)
}
},
"stuck example",
default = Some("default"),
)(ec2)
val cancelled = stuckExample("loop")
assert(i == 1)
assert(cancelled.value.isEmpty)
val normal = stuckExample("normal")
stuckExample.cancelCurrent()
assert(i == 2)
for {
str <- cancelled
_ = assertEquals(str, "default")
str <- normal
_ = assertEquals(str, "result")
} yield ()
}
}
2 changes: 2 additions & 0 deletions tests/unit/src/test/scala/tests/ReportsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,6 @@ class ReportsSuite extends BaseSuite {
assert(Files.exists(pathToReadMe.toNIO))
Files.delete(pathToReadMe.toNIO)
}

test("test") {}
}

0 comments on commit ce3c761

Please sign in to comment.