From 4efa39ac7a5925fa6b5bf32d3bd5684756dc89e7 Mon Sep 17 00:00:00 2001
From: Julius Krumbiegel <22495855+jkrumbiegel@users.noreply.github.com>
Date: Tue, 15 Oct 2024 09:55:33 +0200
Subject: [PATCH] Deterministic seeded random streams across runs (#194)

* Create failing test

* simplify

* feed all eval requests through a single channel

* formatting

* `var"` macros not necessary after all

* typos

* catch errors not seen locally to understand CI failures better

* add env to make `Random` available

* formatting

* remove copypasted code that unnecessarily created a new notebook file away from the project
---
 src/server.jl                             | 16 +++++++++++----
 src/worker.jl                             | 13 ++++++++++++
 test/examples/random_seed/Project.toml    |  2 ++
 test/examples/random_seed/random_seed.qmd | 17 +++++++++++++++
 test/testsets/random_seed.jl              | 25 +++++++++++++++++++++++
 5 files changed, 69 insertions(+), 4 deletions(-)
 create mode 100644 test/examples/random_seed/Project.toml
 create mode 100644 test/examples/random_seed/random_seed.qmd
 create mode 100644 test/testsets/random_seed.jl

diff --git a/src/server.jl b/src/server.jl
index f881d75..f1abab6 100644
--- a/src/server.jl
+++ b/src/server.jl
@@ -92,6 +92,38 @@ end
 
 # Implementation.
 
+"""
+    remote_eval_fetch_channeled(worker, expr)
+
+Evaluate `expr` on `worker` and return the result, routing the evaluation through
+the `stable_execution_task_channel_in`/`_out` channel pair that `worker_init` sets
+up, instead of letting Malt evaluate it directly in a per-request task.
+
+Errors raised by `expr` are captured on the worker and rethrown from the code that
+Malt evaluates, so they surface through `Malt.remote_eval_fetch` rather than leaving
+the request blocked on `take!`.
+"""
+function remote_eval_fetch_channeled(worker, expr)
+    # The channel task evaluates whatever it receives without any error handling, so
+    # an exception there would end that task and never produce a reply. Wrap `expr`
+    # so that evaluating it always yields an `(ok, value)` tuple. The nested
+    # `Core.eval` keeps `expr` a top-level expression of `Main`.
+    guarded = quote
+        try
+            (true, Core.eval(Main, $(QuoteNode(expr))))
+        catch err
+            (false, CapturedException(err, catch_backtrace()))
+        end
+    end
+    code = quote
+        put!(stable_execution_task_channel_in, $(QuoteNode(guarded)))
+        let reply = take!(stable_execution_task_channel_out)
+            first(reply) ? last(reply) : throw(last(reply))
+        end
+    end
+    return Malt.remote_eval_fetch(worker, code)
+end
+
 function init!(file::File, options::Dict)
     worker = file.worker
     Malt.remote_eval_fetch(worker, worker_init(file, options))
@@ -106,7 +114,7 @@ function refresh!(file::File, options::Dict)
         init!(file, options)
     end
     expr = :(refresh!($(options)))
-    Malt.remote_eval_fetch(file.worker, expr)
+    remote_eval_fetch_channeled(file.worker, expr)
 end
 
 """
@@ -648,7 +656,7 @@ function evaluate_raw_cells!(
                     $(chunk.cell_options),
                 ))
 
-                worker_results, expand_cell = Malt.remote_eval_fetch(f.worker, expr)
+                worker_results, expand_cell = remote_eval_fetch_channeled(f.worker, expr)
 
                 # When the result of the cell evaluation is a cell expansion
                 # then we insert the original cell contents before the expanded
@@ -827,7 +835,7 @@ function evaluate_raw_cells!(
                             # inline evaluation since you can't pass cell
                             # options and so `expand` will always be `false`.
                             worker_results, expand_cell =
-                                Malt.remote_eval_fetch(f.worker, expr)
+                                remote_eval_fetch_channeled(f.worker, expr)
                             expand_cell && error("inline code cells cannot be expanded")
                             remote = only(worker_results)
                             if !isnothing(remote.error)
@@ -888,7 +896,7 @@ function evaluate_params!(f, params::Dict{String})
         :(@eval getfield(Main, :Notebook) const $(Symbol(key::String)) = $value)
     end
     expr = Expr(:block, exprs...)
-    Malt.remote_eval_fetch(f.worker, expr)
+    remote_eval_fetch_channeled(f.worker, expr)
     return
 end
 
diff --git a/src/worker.jl b/src/worker.jl
index 986b2c0..68dc44d 100644
--- a/src/worker.jl
+++ b/src/worker.jl
@@ -2,6 +2,19 @@ function worker_init(f::File, options::Dict)
     project = WorkerSetup.LOADER_ENV[]
     return lock(WorkerSetup.WORKER_SETUP_LOCK) do
         return quote
+            # Workaround for issue #192: Malt runs every `remote_eval` in a fresh task, and each task
+            # introduces a new state for its task-local RNG, so random number streams differ between
+            # runs even when seeded. All evaluation requests are therefore fed through this channel
+            # pair so that one long-lived task executes them. NOTE: an exception escaping `Core.eval`
+            # ends that task and nothing is put on the `out` channel, so requests must not throw.
+            const stable_execution_task_channel_out = Channel()
+            const stable_execution_task_channel_in = Channel() do chan
+                for expr in chan
+                    result = Core.eval(Main, expr)
+                    put!(stable_execution_task_channel_out, result)
+                end
+            end
+
             push!(LOAD_PATH, $(project))
 
             let QNW = task_local_storage(:QUARTO_NOTEBOOK_WORKER_OPTIONS, $(options)) do
diff --git a/test/examples/random_seed/Project.toml b/test/examples/random_seed/Project.toml
new file mode 100644
index 0000000..576ba3c
--- /dev/null
+++ b/test/examples/random_seed/Project.toml
@@ -0,0 +1,2 @@
+[deps]
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/test/examples/random_seed/random_seed.qmd b/test/examples/random_seed/random_seed.qmd
new file mode 100644
index 0000000..e97b76c
--- /dev/null
+++ b/test/examples/random_seed/random_seed.qmd
@@ -0,0 +1,17 @@
+---
+title: Random seed
+---
+
+```{julia}
+using Random
+Random.seed!(123)
+rand()
+```
+
+```{julia}
+rand()
+```
+
+```{julia}
+rand()
+```
diff --git a/test/testsets/random_seed.jl b/test/testsets/random_seed.jl
new file mode 100644
index 0000000..01853a4
--- /dev/null
+++ b/test/testsets/random_seed.jl
@@ -0,0 +1,33 @@
+include("../utilities/prelude.jl")
+
+@testset "seeded random numbers are consistent across runs" begin
+    notebook = joinpath(@__DIR__, "../examples/random_seed/random_seed.qmd")
+
+    # Plain-text rendering of a cell that produced exactly one output.
+    _output(cell) = only(cell.outputs).data["text/plain"]
+
+    # Indices of the cells whose outputs are checked (the three `rand()` results).
+    checked_cells = (2, 4, 6)
+
+    server = QuartoNotebookRunner.Server()
+    try
+        # Render the same notebook twice with the same server.
+        jsons = map(1:2) do _
+            QuartoNotebookRunner.run!(server, notebook; showprogress = false)
+        end
+        first_run, second_run = jsons
+
+        for i in checked_cells
+            # The cell output must be a parseable floating point number ...
+            @test tryparse(Float64, _output(first_run.cells[i])) !== nothing
+            # ... and the second run must reproduce it exactly.
+            @test _output(first_run.cells[i]) == _output(second_run.cells[i])
+        end
+
+        # Within one run the three draws must all differ from each other.
+        @test length(unique([_output(first_run.cells[i]) for i in checked_cells])) == 3
+    finally
+        # Shut the server down even when rendering throws.
+        close!(server)
+    end
+end