Skip to content

Commit

Permalink
Added testing of warmup steps
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 4, 2024
1 parent 3b4f6db commit f9142a6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
39 changes: 39 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,45 @@
@test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N)
end

@testset "Warm-up steps" begin
# Create a chain and discard initial samples.
Random.seed!(1234)
N = 100
num_warmup = 50

# Everything should be discarded here.
chain = sample(MyModel(), MySampler(), N; num_warmup=num_warmup)
@test length(chain) == N
@test !ismissing(chain[1].a)

# Repeat sampling without discarding initial samples.
# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
Random.seed!(1234)
ref_chain = sample(
MyModel(), MySampler(), N + num_warmup; progress=VERSION < v"1.6"
)
@test all(chain[i].a == ref_chain[i + num_warmup].a for i in 1:N)
@test all(chain[i].b == ref_chain[i + num_warmup].b for i in 1:N)

# Some other stuff.
Random.seed!(1234)
discard_initial = 10
chain_warmup = sample(
MyModel(),
MySampler(),
N;
num_warmup=num_warmup,
discard_initial=discard_initial,
)
@test length(chain_warmup) == N
@test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N)
# Check that the first `num_warmup - discard_initial` samples are warmup samples.
@test all(
chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N
)
end

@testset "Thin chain by a factor of `thinning`" begin
# Run a thinned chain with `N` samples thinned by factor of `thinning`.
Random.seed!(100)
Expand Down
18 changes: 18 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ struct MyModel <: AbstractMCMC.AbstractModel end
struct MySample{A,B}
a::A
b::B
is_warmup::Bool
end

MySample(a, b) = MySample(a, b, false)

struct MySampler <: AbstractMCMC.AbstractSampler end
struct AnotherSampler <: AbstractMCMC.AbstractSampler end

Expand All @@ -16,6 +19,21 @@ end

MyChain(a, b) = MyChain(a, b, NamedTuple())

function AbstractMCMC.step_warmup(
rng::AbstractRNG,
model::MyModel,
sampler::MySampler,
state::Union{Nothing,Integer}=nothing;
loggers=false,
initial_params=nothing,
kwargs...,
)
transition, state = AbstractMCMC.step(
rng, model, sampler, state; loggers, initial_params, kwargs...
)
return MySample(transition.a, transition.b, true), state
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::MyModel,
Expand Down

0 comments on commit f9142a6

Please sign in to comment.