Skip to content

Commit

Permalink
Deterministic seeded random streams across runs (#194)
Browse files Browse the repository at this point in the history
* Create failing test

* simplify

* feed all eval requests through a single channel

* formatting

* `var"` macros not necessary after all

* typos

* catch errors not seen locally to understand CI failures better

* add env to make `Random` available

* formatting

* remove copypasted code that unnecessarily created a new notebook file away from the project
  • Loading branch information
jkrumbiegel authored Oct 15, 2024
1 parent 84cb855 commit 4efa39a
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/server.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ end

# Implementation.

function remote_eval_fetch_channeled(worker, expr)
code = quote
put!(stable_execution_task_channel_in, $(QuoteNode(expr)))
take!(stable_execution_task_channel_out)
end
return Malt.remote_eval_fetch(worker, code)
end

function init!(file::File, options::Dict)
worker = file.worker
Malt.remote_eval_fetch(worker, worker_init(file, options))
Expand All @@ -106,7 +114,7 @@ function refresh!(file::File, options::Dict)
init!(file, options)
end
expr = :(refresh!($(options)))
Malt.remote_eval_fetch(file.worker, expr)
remote_eval_fetch_channeled(file.worker, expr)
end

"""
Expand Down Expand Up @@ -648,7 +656,7 @@ function evaluate_raw_cells!(
$(chunk.cell_options),
))

worker_results, expand_cell = Malt.remote_eval_fetch(f.worker, expr)
worker_results, expand_cell = remote_eval_fetch_channeled(f.worker, expr)

# When the result of the cell evaluation is a cell expansion
# then we insert the original cell contents before the expanded
Expand Down Expand Up @@ -827,7 +835,7 @@ function evaluate_raw_cells!(
# inline evaluation since you can't pass cell
# options and so `expand` will always be `false`.
worker_results, expand_cell =
Malt.remote_eval_fetch(f.worker, expr)
remote_eval_fetch_channeled(f.worker, expr)
expand_cell && error("inline code cells cannot be expanded")
remote = only(worker_results)
if !isnothing(remote.error)
Expand Down Expand Up @@ -888,7 +896,7 @@ function evaluate_params!(f, params::Dict{String})
:(@eval getfield(Main, :Notebook) const $(Symbol(key::String)) = $value)
end
expr = Expr(:block, exprs...)
Malt.remote_eval_fetch(f.worker, expr)
remote_eval_fetch_channeled(f.worker, expr)
return
end

Expand Down
13 changes: 13 additions & 0 deletions src/worker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ function worker_init(f::File, options::Dict)
project = WorkerSetup.LOADER_ENV[]
return lock(WorkerSetup.WORKER_SETUP_LOCK) do
return quote
# issue #192
# Malt itself uses a new task for each `remote_eval` and because of this, random number streams
# are not consistent across runs even if seeded, as each task introduces a new state for its
# task-local RNG. As a workaround, we feed all `remote_eval` requests through these channels, such
# that the task executing code is always the same.
const stable_execution_task_channel_out = Channel()
const stable_execution_task_channel_in = Channel() do chan
for expr in chan
result = Core.eval(Main, expr)
put!(stable_execution_task_channel_out, result)
end
end

push!(LOAD_PATH, $(project))

let QNW = task_local_storage(:QUARTO_NOTEBOOK_WORKER_OPTIONS, $(options)) do
Expand Down
2 changes: 2 additions & 0 deletions test/examples/random_seed/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
17 changes: 17 additions & 0 deletions test/examples/random_seed/random_seed.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
title: Random seed
---

```{julia}
using Random
Random.seed!(123)
rand()
```

```{julia}
rand()
```

```{julia}
rand()
```
25 changes: 25 additions & 0 deletions test/testsets/random_seed.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
include("../utilities/prelude.jl")

@testset "seeded random numbers are consistent across runs" begin
notebook = joinpath(@__DIR__, "../examples/random_seed/random_seed.qmd")

server = QuartoNotebookRunner.Server()

jsons = map(1:2) do _
QuartoNotebookRunner.run!(server, notebook; showprogress = false)
end

_output(cell) = only(cell.outputs).data["text/plain"]

@test tryparse(Float64, _output(jsons[1].cells[2])) !== nothing
@test tryparse(Float64, _output(jsons[1].cells[4])) !== nothing
@test tryparse(Float64, _output(jsons[1].cells[6])) !== nothing

@test length(unique([_output(jsons[1].cells[i]) for i in [2, 4, 6]])) == 3

@test _output(jsons[1].cells[2]) == _output(jsons[2].cells[2])
@test _output(jsons[1].cells[4]) == _output(jsons[2].cells[4])
@test _output(jsons[1].cells[6]) == _output(jsons[2].cells[6])

close!(server)
end

0 comments on commit 4efa39a

Please sign in to comment.