Skip to content

Commit

Permalink
Merge pull request #47 from lucidsoftware/fix-print-output-and-bootst…
Browse files Browse the repository at this point in the history
…rap-nondeterminism

Fix non-worker output and also bootstrap scalac nondeterminism
  • Loading branch information
jjudd authored Jul 26, 2024
2 parents f781c28 + 2ba3db6 commit c8c4345
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 68 deletions.
44 changes: 24 additions & 20 deletions rules/private/phases/phase_bootstrap_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def phase_bootstrap_compile(ctx, g):
transitive = [ctx.attr._jdk[java_common.JavaRuntimeInfo].files, g.classpaths.compile, g.classpaths.compiler],
)

tmp = ctx.actions.declare_directory("{}/tmp/classes".format(ctx.label.name))

compiler_classpath = ":".join([f.path for f in g.classpaths.compiler.to_list()])
compile_classpath = ":".join([f.path for f in g.classpaths.compile.to_list()])
srcs = " ".join([f.path for f in g.classpaths.srcs])
Expand All @@ -31,37 +33,39 @@ def phase_bootstrap_compile(ctx, g):
if int(scala_configuration.version[0]) >= 3:
main_class = "dotty.tools.dotc.Main"

ctx.actions.run_shell(
inputs = inputs,
tools = [ctx.executable._jar_creator],
outputs = [g.classpaths.jar],
command = _strip_margin(
"""
command = _strip_margin(
"""
|set -eo pipefail
|
|mkdir -p tmp/classes
|
|{java} \\
| -cp {compiler_classpath} \\
| {main_class} \\
| -cp {compile_classpath} \\
| -d tmp/classes \\
| -d {tmp} \\
| {global_scalacopts} \\
| {scalacopts} \\
| {srcs}
|
|{jar_creator} {output_jar} tmp/classes 2> /dev/null
|{jar_creator} {output_jar} {tmp} 2> /dev/null
|""".format(
compiler_classpath = compiler_classpath,
compile_classpath = compile_classpath,
global_scalacopts = " ".join(scala_configuration.global_scalacopts),
java = ctx.attr._jdk[java_common.JavaRuntimeInfo].java_executable_exec_path,
jar_creator = ctx.executable._jar_creator.path,
main_class = main_class,
output_jar = g.classpaths.jar.path,
scalacopts = " ".join(ctx.attr.scalacopts),
srcs = srcs,
),
compiler_classpath = compiler_classpath,
compile_classpath = compile_classpath,
global_scalacopts = " ".join(scala_configuration.global_scalacopts),
java = ctx.attr._jdk[java_common.JavaRuntimeInfo].java_executable_exec_path,
jar_creator = ctx.executable._jar_creator.path,
main_class = main_class,
output_jar = g.classpaths.jar.path,
scalacopts = " ".join(ctx.attr.scalacopts),
srcs = srcs,
tmp = tmp.path,
),
)

ctx.actions.run_shell(
inputs = inputs,
tools = [ctx.executable._jar_creator],
mnemonic = "BootstrapScalacompile",
outputs = [g.classpaths.jar, tmp],
command = command,
execution_requirements = _resolve_execution_reqs(ctx, {}),
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, OutputStream, Print
import java.util.concurrent.ForkJoinPool
import scala.annotation.tailrec
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
import scala.util.{Failure, Success, Using}

trait WorkerMain[S] {

Expand Down Expand Up @@ -38,9 +38,6 @@ trait WorkerMain[S] {
)
val ec = ExecutionContext.fromExecutor(fjp)

System.setIn(new ByteArrayInputStream(Array.emptyByteArray))
System.setOut(System.err)

def writeResponse(requestId: Int, outStream: OutputStream, code: Int) = {
// Defined here so all writes to stdout are synchronized
stdout.synchronized {
Expand All @@ -53,56 +50,84 @@ trait WorkerMain[S] {
}
}

try {
@tailrec
def process(ctx: S): Unit = {
val request = WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin)
if (request == null) {
return
}
val args = request.getArgumentsList.toArray(Array.empty[String])
@tailrec
def process(ctx: S): Unit = {
val request = WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin)
if (request == null) {
return
}
val args = request.getArgumentsList.toArray(Array.empty[String])

val outStream = new ByteArrayOutputStream
val out = new PrintStream(outStream)
val requestId = request.getRequestId()
System.out.println(s"WorkRequest $requestId received with args: ${request.getArgumentsList}")
// We go through this hullabaloo with output streams being defined out here, so we can
// close them after the async work in the Future is all done.
// If we do something synchronous with Using, then there's a race condition where the
// streams can get closed before the Future is completed.
var outStream: ByteArrayOutputStream = null
var out: PrintStream = null

val f: Future[Int] = Future {
try {
work(ctx, args, out)
0
} catch {
case AnnexWorkerError(code, _, _) => code
}
}(ec)
val requestId = request.getRequestId()
System.out.println(s"WorkRequest $requestId received with args: ${request.getArgumentsList}")

f.onComplete {
case Success(code) => {
out.flush()
writeResponse(requestId, outStream, code)
System.out.println(s"WorkResponse $requestId sent with code $code")
}
case Failure(e) => {
e.printStackTrace(out)
out.flush()
writeResponse(requestId, outStream, -1)
System.err.println(s"Uncaught exception in Future while proccessing WorkRequest $requestId:")
e.printStackTrace(System.err)
}
val f: Future[Int] = Future {
outStream = new ByteArrayOutputStream
out = new PrintStream(outStream)
try {
work(ctx, args, out)
0
} catch {
case AnnexWorkerError(code, _, _) => code
}
}(ec)

f.andThen {
case Success(code) => {
out.flush()
writeResponse(requestId, outStream, code)
System.out.println(s"WorkResponse $requestId sent with code $code")
}
case Failure(e) => {
e.printStackTrace(out)
out.flush()
writeResponse(requestId, outStream, -1)
System.err.println(s"Uncaught exception in Future while proccessing WorkRequest $requestId:")
e.printStackTrace(System.err)
}
}(scala.concurrent.ExecutionContext.global)
.andThen { case _ =>
out.close()
outStream.close()
}(scala.concurrent.ExecutionContext.global)
process(ctx)
process(ctx)
}

Using.resource(new ByteArrayInputStream(Array.emptyByteArray)) { inStream =>
try {
System.setIn(inStream)
System.setOut(System.err)
process(init(Some(args.toArray)))
} finally {
System.setIn(stdin)
System.setOut(stdout)
}
process(init(Some(args.toArray)))
} finally {
System.setIn(stdin)
System.setOut(stdout)
}

case args => {
val outStream = new ByteArrayOutputStream
val out = new PrintStream(outStream)
work(init(None), args.toArray, out)
}
case args =>
Using.Manager { use =>
val outStream = use(new ByteArrayOutputStream)
val out = use(new PrintStream(outStream))
try {
work(init(None), args.toArray, out)
} catch {
// This error means the work function encountered an error that we want to not be caught
// inside that function. That way it stops work and exits the function. However, we
// also don't want to crash the whole program.
case AnnexWorkerError(_, _, _) => {}
} finally {
out.flush()
}

outStream.writeTo(System.err)
}.get
}
}
}
3 changes: 2 additions & 1 deletion tests/compile/error/test
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/bin/bash -e
. "$(dirname "$0")"/../../common.sh

bazel build :lib 2>&1 | grep -q $'\[\e\[31mError\e\[0m\] compile/error/Example\.scala:'
bazel build --strategy=ScalaCompile=worker :lib 2>&1 | grep -q $'\[\e\[31mError\e\[0m\] compile/error/Example\.scala:'
bazel build --strategy=ScalaCompile=local :lib 2>&1 | grep -q $'\[\e\[31mError\e\[0m\] compile/error/Example\.scala:'

0 comments on commit c8c4345

Please sign in to comment.