Skip to content

Commit

Permalink
Made the test suite a bit more consistent by using explicit RNGs here
Browse files Browse the repository at this point in the history
and there
  • Loading branch information
torfjelde committed Sep 30, 2024
1 parent 2a8ab0e commit a4ec00b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 22 deletions.
30 changes: 23 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ function test_and_sample_model(
param_names=missing,
progress=false,
minimum_roundtrips=nothing,
rng=make_rng(),
kwargs...
)
# Make the tempered sampler.
Expand All @@ -65,7 +66,7 @@ function test_and_sample_model(

# Sample.
samples_tempered = AbstractMCMC.sample(
model, sampler_tempered, num_iterations;
rng, model, sampler_tempered, num_iterations;
callback=callback, progress=progress, initial_params=initial_params,
kwargs...
)
Expand Down Expand Up @@ -328,13 +329,26 @@ end
adapt=false,
# Make sure we have _some_ roundtrips.
minimum_roundtrips=10,
rng=make_rng(),
)

compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true)
# Some swap strategies are not great.
ess_slack_ratio = if swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.SingleSwap}
0.25
else
0.5
end
compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true, compare_ess_slack=ess_slack_ratio)
end
end

@testset "Turing.jl" begin
# Let's make a default seed we can `deepcopy` throughout to get reproducible results.
seed = 42

# And let's set the seed explicitly for reproducibility.
Random.seed!(seed)

# Instantiate model.
DynamicPPL.@model function demo_model(x)
s ~ Exponential()
Expand Down Expand Up @@ -379,7 +393,7 @@ end

# Sample using HMC.
samples_hmc = sample(
model, sampler_hmc, num_iterations;
make_rng(seed), model, sampler_hmc, num_iterations;
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
initial_params=copy(initial_params),
progress=false
Expand All @@ -393,7 +407,7 @@ end
# Make sure that we get the "same" result when only using the inverse temperature 1.
sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1])
chain_tempered = sample(
model, sampler_tempered, num_iterations;
make_rng(seed), model, sampler_tempered, num_iterations;
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
initial_params=copy(initial_params),
chain_type=MCMCChains.Chains,
Expand Down Expand Up @@ -421,6 +435,7 @@ end
param_names=param_names,
progress=false,
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
rng=make_rng(seed),
)
map_parameters!(b, chain_tempered)
compare_chains(
Expand All @@ -441,7 +456,7 @@ end

# Sample using MALA.
chain_mh = AbstractMCMC.sample(
model, sampler_mh, num_iterations;
make_rng(), model, sampler_mh, num_iterations;
initial_params=copy(initial_params),
progress=false,
chain_type=MCMCChains.Chains,
Expand All @@ -452,7 +467,7 @@ end
# Make sure that we get the "same" result when only using the inverse temperature 1.
sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1])
chain_tempered = sample(
model, sampler_tempered, num_iterations;
make_rng(), model, sampler_tempered, num_iterations;
initial_params=copy(initial_params),
chain_type=MCMCChains.Chains,
param_names=param_names,
Expand All @@ -476,7 +491,8 @@ end
adapt=false,
mean_swap_rate_bound=0.1,
initial_params=copy(initial_params),
param_names=param_names
param_names=param_names,
rng=make_rng(),
)
map_parameters!(b, chain_tempered)

Expand Down
40 changes: 32 additions & 8 deletions test/simple_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,76 @@
# Sample.
@testset "TemperedSampler" begin
chains_product = sample(
DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples;
make_rng(), DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples;
initial_params,
bundle_resolve_swaps=true,
chain_type=Vector{MCMCChains.Chains},
progress=false,
discard_initial=num_burnin,
thinning=thin,
)
test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
test_chains_with_monotonic_variance(
chains_product,
Zeros(length(chains_product)),
std_true_dict,
min_atol=2e-1,
max_atol=5e-1
)
end

@testset "MultiSampler without swapping" begin
chains_product = sample(
tempered_multimodel, rwmh_product, num_samples;
make_rng(), tempered_multimodel, rwmh_product, num_samples;
initial_params,
chain_type=Vector{MCMCChains.Chains},
progress=false,
discard_initial=num_burnin,
thinning=thin,
)
test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
test_chains_with_monotonic_variance(
chains_product,
Zeros(length(chains_product)),
std_true_dict,
min_atol=2e-1,
max_atol=5e-1
)
end

@testset "MultiSampler with swapping (saveall=true)" begin
chains_product = sample(
tempered_multimodel, rwmh_product_with_swap, num_samples;
make_rng(), tempered_multimodel, rwmh_product_with_swap, num_samples;
initial_params,
bundle_resolve_swaps=true,
chain_type=Vector{MCMCChains.Chains},
progress=false,
discard_initial=num_burnin,
thinning=thin,
)
test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
test_chains_with_monotonic_variance(
chains_product,
Zeros(length(chains_product)),
std_true_dict,
min_atol=2e-1,
max_atol=5e-1
)
end

@testset "MultiSampler with swapping (saveall=false)" begin
chains_product = sample(
tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples;
make_rng(), tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples;
initial_params,
chain_type=Vector{MCMCChains.Chains},
progress=false,
discard_initial=num_burnin,
thinning=thin,
)
test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict)
test_chains_with_monotonic_variance(
chains_product,
Zeros(length(chains_product)),
std_true_dict,
min_atol=3e-1,
max_atol=5e-1
)
end
end

27 changes: 20 additions & 7 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@ function to_dict(c::MCMCChains.ChainDataFrame, col::Symbol)
end

"""
atol_for_chain(chain; significance=1e-2, kind=Statistics.mean)
atol_for_chain(chain; significance=0.05, kind=Statistics.mean, min_atol=0, max_atol=Inf)
Return a dictionary of absolute tolerances for each parameter in `chain`, computed
as the confidence interval width for the mean of the parameter with `significance`.
"""
function atol_for_chain(chain; significance=1e-2, kind=Statistics.mean, min_atol=Inf)
function atol_for_chain(chain; significance=0.05, kind=Statistics.mean, min_atol=0, max_atol=Inf)
param_names = names(chain, :parameters)
# Can reject H0 if, say, `abs(mean(chain2) - mean(chain1)) > confidence_width`.
# Or alternatively, compare means but with `atol` set to the `confidence_width`.
# NOTE: Failure to reject, i.e. passing the tests, does not imply that the means are equal.
mcse = to_dict(MCMCChains.mcse(chain; kind), :mcse)
return Dict(sym => min(min_atol, quantile(Normal(0, mcse[sym]), 1 - significance/2)) for sym in param_names)
return Dict(
sym => max(min_atol, min(max_atol, quantile(Normal(0, mcse[sym]), 1 - significance/2)))
for sym in param_names
)
end

"""
    thin_to(chain, n)

Thin `chain` to roughly `n` entries by keeping every `length(chain) ÷ n`-th element,
starting from the first one.

The stride comes from an integer division, so the result may contain somewhat more
than `n` entries. If `n` is larger than `length(chain)` the quotient is zero; in that
case a stride of one is used (i.e. everything is kept) instead of building a
zero-step range, which would throw an `ArgumentError`.
"""
function thin_to(chain, n)
    stride = length(chain) ÷ n
    # A zero stride (n > length(chain)) is not a valid range step; keep every entry.
    stride = iszero(stride) ? one(stride) : stride
    return chain[1:stride:end]
end
Expand All @@ -48,7 +51,7 @@ end
function test_means(chain::MCMCChains.Chains, mean_true::AbstractDict; n=length(chain), kwargs...)
chain = thin_to(chain, n)
atol = atol_for_chain(chain; kwargs...)
@info "mean" [(mean(chain[sym]), atol[sym]) for sym in names(chain, :parameters)]
@debug "mean" [(mean(chain[sym]), atol[sym]) for sym in names(chain, :parameters)]
@test all(isapprox(mean(chain[sym]), 0, atol=atol[sym]) for sym in names(chain, :parameters))
end

Expand All @@ -68,7 +71,7 @@ end
function test_std(chain::MCMCChains.Chains, std_true::AbstractDict; n=length(chain), kwargs...)
chain = thin_to(chain, n)
atol = atol_for_chain(chain; kind=Statistics.std, kwargs...)
@info "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)]
@debug "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)]
@test all(isapprox(std(chain[sym]), std_true[sym], atol=atol[sym]) for sym in names(chain, :parameters))
end

Expand Down Expand Up @@ -118,10 +121,20 @@ and `std_true`, respectively. Also test that the standard deviation is monotonic
- `significance`: The significance level of the test.
- `kwargs...`: Passed to `atol_for_chain`.
"""
function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...)
function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=0.05, kwargs...)
@testset "chain $i" for i = 1:length(chains)
test_means(chains[i], mean_true[i]; kwargs...)
test_std(chains[i], std_true[i]; kwargs...)
end
test_std_monotonicity(chains; significance=0.05)
test_std_monotonicity(chains; significance=significance)
end

"""
make_rng([seed])
Create a random number generator.
# Arguments
- `seed`: The seed for the random number generator. Default is `42`.
"""
function make_rng(seed=42)
    # Hand the seed straight to the Xoshiro constructor; every call builds a new
    # generator, so callers never share RNG state with each other.
    generator = Random.Xoshiro(seed)
    return generator
end

0 comments on commit a4ec00b

Please sign in to comment.