From f9142a6a205144b79adba27e8b7ca258157d5b4f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 12:07:49 +0100 Subject: [PATCH] Added testing of warmup steps --- test/sample.jl | 39 +++++++++++++++++++++++++++++++++++++++ test/utils.jl | 18 ++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/test/sample.jl b/test/sample.jl index dcc87526..7599bd79 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -575,6 +575,45 @@ @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N) end + @testset "Warm-up steps" begin + # Create a chain and discard initial samples. + Random.seed!(1234) + N = 100 + num_warmup = 50 + + # Everything should be discarded here. + chain = sample(MyModel(), MySampler(), N; num_warmup=num_warmup) + @test length(chain) == N + @test !ismissing(chain[1].a) + + # Repeat sampling without discarding initial samples. + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + ref_chain = sample( + MyModel(), MySampler(), N + num_warmup; progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i + num_warmup].a for i in 1:N) + @test all(chain[i].b == ref_chain[i + num_warmup].b for i in 1:N) + + # Some other stuff. + Random.seed!(1234) + discard_initial = 10 + chain_warmup = sample( + MyModel(), + MySampler(), + N; + num_warmup=num_warmup, + discard_initial=discard_initial, + ) + @test length(chain_warmup) == N + @test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N) + # Check that the first `num_warmup - discard_initial` samples are warmup samples. + @test all( + chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N + ) + end + @testset "Thin chain by a factor of `thinning`" begin # Run a thinned chain with `N` samples thinned by factor of `thinning`. Random.seed!(100) diff --git a/test/utils.jl b/test/utils.jl index 1e29a473..b041b3a7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,8 +3,11 @@ struct MyModel <: AbstractMCMC.AbstractModel end struct MySample{A,B} a::A b::B + is_warmup::Bool end +MySample(a, b) = MySample(a, b, false) + struct MySampler <: AbstractMCMC.AbstractSampler end struct AnotherSampler <: AbstractMCMC.AbstractSampler end @@ -16,6 +19,21 @@ end MyChain(a, b) = MyChain(a, b, NamedTuple()) +function AbstractMCMC.step_warmup( + rng::AbstractRNG, + model::MyModel, + sampler::MySampler, + state::Union{Nothing,Integer}=nothing; + loggers=false, + initial_params=nothing, + kwargs..., +) + transition, state = AbstractMCMC.step( + rng, model, sampler, state; loggers, initial_params, kwargs... + ) + return MySample(transition.a, transition.b, true), state +end + function AbstractMCMC.step( rng::AbstractRNG, model::MyModel,