Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improvement: cancel current request in batched functions #5432

Merged
merged 3 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package scala.meta.internal.metals

import java.util.concurrent.CancellationException
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.util.control.NonFatal

import scala.meta.internal.async.ConcurrentQueue
Expand All @@ -22,6 +25,7 @@ final class BatchedFunction[A, B](
fn: Seq[A] => CancelableFuture[B],
functionId: String,
shouldLogQueue: Boolean = false,
default: Option[B] = None,
)(implicit ec: ExecutionContext)
extends (Seq[A] => Future[B])
with Function2[Seq[A], () => Unit, Future[B]]
Expand Down Expand Up @@ -75,8 +79,17 @@ final class BatchedFunction[A, B](
}

def cancelAll(): Unit = {
queue.clear()
unlock()
val requests = ConcurrentQueue.pollAll(queue)
requests.foreach(_.result.complete(defaultResult))
cancelCurrent()
}

def cancelCurrent(): Unit = {
lock.get() match {
case None =>
case Some(promise) =>
promise.tryFailure(new BatchedFunction.BatchedFunctionCancelation)
}
}

def currentFuture(): Future[B] = {
Expand All @@ -97,22 +110,28 @@ final class BatchedFunction[A, B](
callback: () => Unit,
)

private val lock = new AtomicBoolean()
private val lock = new AtomicReference[Option[Promise[B]]](None)

private def unlock(): Unit = {
lock.set(false)
lock.set(None)
if (!queue.isEmpty) {
runAcquire()
}
}
private def runAcquire(): Unit = {
if (!isPaused.get() && lock.compareAndSet(false, true)) {
runRelease()
lazy val promise = {
val p = Promise[B]
p.future.onComplete { _ => unlock() }
p
}
if (!isPaused.get() && lock.compareAndSet(None, Some(promise))) {
runRelease(promise)
} else {
// Do nothing, the submitted arguments will be handled
// by a separate request.
}
}
private def runRelease(): Unit = {
private def runRelease(p: Promise[B]): Unit = {
// Pre-condition: lock is acquired.
// Pos-condition:
// - lock is released
Expand All @@ -128,37 +147,45 @@ final class BatchedFunction[A, B](
this.current.set(result)
val resultF = for {
result <- result.future
_ <- Future {
callbacks.foreach(cb => cb())
}
_ <- Future { callbacks.foreach(cb => cb()) }
} yield result
resultF.onComplete { response =>
unlock()
requests.foreach(_.result.complete(response))
resultF.onComplete(p.tryComplete)
p.future.onComplete {
case Failure(_: BatchedFunction.BatchedFunctionCancelation) =>
result.cancel()
requests.foreach(_.result.complete(defaultResult))
case result =>
requests.foreach(_.result.complete(result))
}
} else {
unlock()
p.tryFailure(new BatchedFunction.BatchedFunctionCancelation)
}
} catch {
case NonFatal(e) =>
unlock()
requests.foreach(_.result.failure(e))
requests.foreach(_.result.tryFailure(e))
scribe.error(s"Unexpected error releasing buffered job", e)
}
}

def defaultResult: Try[B] =
default.map(Success(_)).getOrElse(Failure(new CancellationException))
}

object BatchedFunction {
def fromFuture[A, B](
fn: Seq[A] => Future[B],
functionId: String,
shouldLogQueue: Boolean = false,
default: Option[B] = None,
)(implicit
ec: ExecutionContext
): BatchedFunction[A, B] =
new BatchedFunction(
fn.andThen(CancelableFuture(_)),
functionId,
shouldLogQueue,
default,
)
class BatchedFunctionCancelation extends RuntimeException
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import scala.concurrent.ExecutionContextExecutorService
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.reflect.ClassTag
import scala.util.Success
import scala.util.Try

import scala.meta.internal.builds.MillBuildTool
Expand Down Expand Up @@ -73,7 +74,6 @@ class BuildServerConnection private (

private val ongoingRequests =
new MutableCancelable().addAll(initialConnection.cancelables)
private val ongoingCompilations = new MutableCancelable()

def version: String = _version.get()

Expand Down Expand Up @@ -154,7 +154,6 @@ class BuildServerConnection private (
def compile(params: CompileParams): CompletableFuture[CompileResult] = {
register(
server => server.buildTargetCompile(params),
isCompile = true,
onFail = Some(
(
new CompileResult(StatusCode.CANCELLED),
Expand Down Expand Up @@ -299,14 +298,9 @@ class BuildServerConnection private (
override def cancel(): Unit = {
if (cancelled.compareAndSet(false, true)) {
ongoingRequests.cancel()
ongoingCompilations.cancel()
}
}

def cancelCompilations(): Unit = {
ongoingCompilations.cancel()
}

private def askUser(
original: Future[BuildServerConnection.LauncherConnection]
): Future[BuildServerConnection.LauncherConnection] = {
Expand Down Expand Up @@ -356,24 +350,23 @@ class BuildServerConnection private (
private def register[T: ClassTag](
action: MetalsBuildServer => CompletableFuture[T],
onFail: => Option[(T, String)] = None,
isCompile: Boolean = false,
): CompletableFuture[T] = {

val localCancelable = new MutableCancelable()
def runWithCanceling(
launcherConnection: BuildServerConnection.LauncherConnection
): Future[T] = {
val resultFuture = action(launcherConnection.server)
val cancelable = Cancelable { () =>
Try(resultFuture.cancel(true))
}
if (isCompile) ongoingCompilations.add(cancelable)
else ongoingRequests.add(cancelable)
ongoingRequests.add(cancelable)
localCancelable.add(cancelable)

val result = resultFuture.asScala

result.onComplete { _ =>
if (isCompile) ongoingCompilations.remove(cancelable)
else ongoingRequests.remove(cancelable)
ongoingRequests.remove(cancelable)
localCancelable.remove(cancelable)
}
result
}
Expand Down Expand Up @@ -410,7 +403,14 @@ class BuildServerConnection private (
Future.failed(new MetalsBspException(name, t))
})
}
CancelTokens.future(_ => actionFuture)

CancelTokens.future { token =>
token.onCancel().asScala.onComplete {
case Success(java.lang.Boolean.TRUE) => localCancelable.cancel()
case _ =>
}
actionFuture
}
}

def isBuildServerResponsive: Future[Option[Boolean]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ final class Compilations(
new BatchedFunction[
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](compile, "compileBatch", shouldLogQueue = true)
](compile, "compileBatch", shouldLogQueue = true, Some(Map.empty))
private val cascadeBatch =
new BatchedFunction[
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](compile, "cascadeBatch", shouldLogQueue = true)
](compile, "cascadeBatch", shouldLogQueue = true, Some(Map.empty))
def pauseables: List[Pauseable] = List(compileBatch, cascadeBatch)

private val isCompiling = TrieMap.empty[b.BuildTargetIdentifier, Boolean]
Expand Down Expand Up @@ -115,15 +115,6 @@ final class Compilations(
def cancel(): Unit = {
cascadeBatch.cancelAll()
compileBatch.cancelAll()
buildTargets.all
.flatMap { target =>
buildTargets.buildServerOf(target.getId())
}
.distinct
.foreach { conn =>
conn.cancelCompilations()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was done on purpose since the cancelation was not properly propagated before

}

}

def recompileAll(): Future[Unit] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,8 @@ class MetalsLspService(
}

def disconnectOldBuildServer(): Future[Unit] = {
compilations.cancel()
buildTargetClasses.cancel()
diagnostics.reset()
bspSession.foreach(connection =>
scribe.info(s"Disconnecting from ${connection.main.name} session...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ final class BuildTargetClasses(
: TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] =
TrieMap.empty[b.BuildTargetIdentifier, b.JvmEnvironmentItem]
val rebuildIndex: BatchedFunction[b.BuildTargetIdentifier, Unit] =
BatchedFunction.fromFuture(fetchClasses, "buildTargetClasses")
BatchedFunction.fromFuture(
fetchClasses,
"buildTargetClasses",
default = Some(()),
)

def classesOf(target: b.BuildTargetIdentifier): Classes = {
index.getOrElse(target, new Classes)
Expand Down Expand Up @@ -175,6 +179,10 @@ final class BuildTargetClasses(
val name = NameTransformer.decode(names.last)
descriptors.map(descriptor => Symbols.Global(prefix, descriptor(name)))
}

def cancel(): Unit = {
rebuildIndex.cancelAll()
}
}

sealed abstract class TestFramework(val canResolveChildren: Boolean)
Expand Down
54 changes: 49 additions & 5 deletions tests/unit/src/test/scala/tests/BatchedFunctionSuite.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package tests

import java.util.concurrent.Executors

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.util.Success

import scala.meta.internal.metals.BatchedFunction
import scala.meta.internal.metals.Cancelable
import scala.meta.internal.metals.CancelableFuture

class BatchedFunctionSuite extends BaseSuite {
test("batch") {
Expand Down Expand Up @@ -95,11 +99,51 @@ class BatchedFunctionSuite extends BaseSuite {

mkString.unpause()

assertDiffEqual(paused.value, None)
assertDiffEqual(paused2.value, None)

val unpaused2 = mkString(List("a", "b"))
assertDiffEqual(unpaused2.value, Some(Success("ab")))
for {
_ <- paused.failed
_ <- paused2.failed
res <- mkString(List("a", "b"))
_ = assertEquals(res, "ab")
} yield ()
}

test("cancel2") {
val executorService = Executors.newFixedThreadPool(10)
val ec2 = ExecutionContext.fromExecutor(executorService)
var i = 1
val stuckExample: BatchedFunction[String, String] =
new BatchedFunction(
(seq: Seq[String]) => {
seq.toList match {
case "loop" :: Nil =>
val future = Future.apply {
tgodzik marked this conversation as resolved.
Show resolved Hide resolved
while (i == 1) {
Thread.sleep(1)
}
"loop-result"
}(ec2)
CancelableFuture[String](future, Cancelable { () => i = 2 })
case _ =>
CancelableFuture[String](
Future.successful("result"),
Cancelable.empty,
)
}
},
"stuck example",
default = Some("default"),
)(ec2)
val cancelled = stuckExample("loop")
assertEquals(i, 1)
assert(cancelled.value.isEmpty)
val normal = stuckExample("normal")
stuckExample.cancelCurrent()
for {
str <- cancelled
_ = assertEquals(i, 2)
_ = assertEquals(str, "default")
str <- normal
_ = assertEquals(str, "result")
} yield ()
}
}
Loading
Loading