diff --git a/test/runtests.jl b/test/runtests.jl index ecd7598..8061496 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,7 @@ function test_and_sample_model( param_names=missing, progress=false, minimum_roundtrips=nothing, + rng=make_rng(), kwargs... ) # Make the tempered sampler. @@ -65,7 +66,7 @@ function test_and_sample_model( # Sample. samples_tempered = AbstractMCMC.sample( - model, sampler_tempered, num_iterations; + rng, model, sampler_tempered, num_iterations; callback=callback, progress=progress, initial_params=initial_params, kwargs... ) @@ -328,13 +329,26 @@ end adapt=false, # Make sure we have _some_ roundtrips. minimum_roundtrips=10, + rng=make_rng(), ) - compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true) + # Some swap strategies are not great. + ess_slack_ratio = if swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.SingleSwap} + 0.25 + else + 0.5 + end + compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true, compare_ess_slack=ess_slack_ratio) end end @testset "Turing.jl" begin + # Let's make a default seed we can `deepcopy` throughout to get reproducible results. + seed = 42 + + # And let's set the seed explicitly for reproducibility. + Random.seed!(seed) + # Instantiate model. DynamicPPL.@model function demo_model(x) s ~ Exponential() @@ -379,7 +393,7 @@ end # Sample using HMC. samples_hmc = sample( - model, sampler_hmc, num_iterations; + make_rng(seed), model, sampler_hmc, num_iterations; n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed. initial_params=copy(initial_params), progress=false @@ -393,7 +407,7 @@ end # Make sure that we get the "same" result when only using the inverse temperature 1. sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1]) chain_tempered = sample( - model, sampler_tempered, num_iterations; + make_rng(seed), model, sampler_tempered, num_iterations; n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed. 
initial_params=copy(initial_params), chain_type=MCMCChains.Chains, @@ -421,6 +435,7 @@ end param_names=param_names, progress=false, n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed. + rng=make_rng(seed), ) map_parameters!(b, chain_tempered) compare_chains( @@ -441,7 +456,7 @@ end # Sample using MALA. chain_mh = AbstractMCMC.sample( - model, sampler_mh, num_iterations; + make_rng(), model, sampler_mh, num_iterations; initial_params=copy(initial_params), progress=false, chain_type=MCMCChains.Chains, @@ -452,7 +467,7 @@ end # Make sure that we get the "same" result when only using the inverse temperature 1. sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1]) chain_tempered = sample( - model, sampler_tempered, num_iterations; + make_rng(), model, sampler_tempered, num_iterations; initial_params=copy(initial_params), chain_type=MCMCChains.Chains, param_names=param_names, @@ -476,7 +491,8 @@ end adapt=false, mean_swap_rate_bound=0.1, initial_params=copy(initial_params), - param_names=param_names + param_names=param_names, + rng=make_rng(), ) map_parameters!(b, chain_tempered) diff --git a/test/simple_gaussian.jl b/test/simple_gaussian.jl index 5b26bf5..2c49527 100644 --- a/test/simple_gaussian.jl +++ b/test/simple_gaussian.jl @@ -23,7 +23,7 @@ # Sample. 
@testset "TemperedSampler" begin chains_product = sample( - DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples; + make_rng(), DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples; initial_params, bundle_resolve_swaps=true, chain_type=Vector{MCMCChains.Chains}, @@ -31,24 +31,36 @@ discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=2e-1, + max_atol=5e-1 + ) end @testset "MultiSampler without swapping" begin chains_product = sample( - tempered_multimodel, rwmh_product, num_samples; + make_rng(), tempered_multimodel, rwmh_product, num_samples; initial_params, chain_type=Vector{MCMCChains.Chains}, progress=false, discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=2e-1, + max_atol=5e-1 + ) end @testset "MultiSampler with swapping (saveall=true)" begin chains_product = sample( - tempered_multimodel, rwmh_product_with_swap, num_samples; + make_rng(), tempered_multimodel, rwmh_product_with_swap, num_samples; initial_params, bundle_resolve_swaps=true, chain_type=Vector{MCMCChains.Chains}, @@ -56,19 +68,31 @@ discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=2e-1, + max_atol=5e-1 + ) end @testset "MultiSampler with swapping (saveall=true)" begin chains_product = sample( - tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples; + make_rng(), tempered_multimodel, 
Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples; initial_params, chain_type=Vector{MCMCChains.Chains}, progress=false, discard_initial=num_burnin, thinning=thin, ) - test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + test_chains_with_monotonic_variance( + chains_product, + Zeros(length(chains_product)), + std_true_dict, + min_atol=3e-1, + max_atol=5e-1 + ) end end diff --git a/test/test_utils.jl b/test/test_utils.jl index a59e5e4..d13319a 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -16,18 +16,21 @@ function to_dict(c::MCMCChains.ChainDataFrame, col::Symbol) end """ - atol_for_chain(chain; significance=1e-2, kind=Statistics.mean) + atol_for_chain(chain; significance=0.05, kind=Statistics.mean, min_atol=0, max_atol=Inf) Return a dictionary of absolute tolerances for each parameter in `chain`, computed as the confidence interval width for the mean of the parameter with `significance`. """ -function atol_for_chain(chain; significance=1e-2, kind=Statistics.mean, min_atol=Inf) +function atol_for_chain(chain; significance=0.05, kind=Statistics.mean, min_atol=0, max_atol=Inf) param_names = names(chain, :parameters) # Can reject H0 if, say, `abs(mean(chain2) - mean(chain1)) > confidence_width`. # Or alternatively, compare means but with `atol` set to the `confidence_width`. # NOTE: Failure to reject, i.e. passing the tests, does not imply that the means are equal. mcse = to_dict(MCMCChains.mcse(chain; kind), :mcse) - return Dict(sym => min(min_atol, quantile(Normal(0, mcse[sym]), 1 - significance/2)) for sym in param_names) + return Dict( + sym => max(min_atol, min(max_atol, quantile(Normal(0, mcse[sym]), 1 - significance/2))) + for sym in param_names + ) end thin_to(chain, n) = chain[1:length(chain) ÷ n:end] @@ -48,7 +51,7 @@ end function test_means(chain::MCMCChains.Chains, mean_true::AbstractDict; n=length(chain), kwargs...) 
chain = thin_to(chain, n) atol = atol_for_chain(chain; kwargs...) - @info "mean" [(mean(chain[sym]), atol[sym]) for sym in names(chain, :parameters)] + @debug "mean" [(mean(chain[sym]), atol[sym]) for sym in names(chain, :parameters)] @test all(isapprox(mean(chain[sym]), 0, atol=atol[sym]) for sym in names(chain, :parameters)) end @@ -68,7 +71,7 @@ end function test_std(chain::MCMCChains.Chains, std_true::AbstractDict; n=length(chain), kwargs...) chain = thin_to(chain, n) atol = atol_for_chain(chain; kind=Statistics.std, kwargs...) - @info "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)] + @debug "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)] @test all(isapprox(std(chain[sym]), std_true[sym], atol=atol[sym]) for sym in names(chain, :parameters)) end @@ -118,10 +121,20 @@ and `std_true`, respectively. Also test that the standard deviation is monotonic - `significance`: The significance level of the test. - `kwargs...`: Passed to `atol_for_chain`. """ -function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...) +function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=0.05, kwargs...) @testset "chain $i" for i = 1:length(chains) test_means(chains[i], mean_true[i]; kwargs...) test_std(chains[i], std_true[i]; kwargs...) end - test_std_monotonicity(chains; significance=0.05) + test_std_monotonicity(chains; significance=significance) end + +""" + make_rng([seed]) + +Create a random number generator. + +# Arguments +- `seed`: The seed for the random number generator. Default is `42`. +""" +make_rng(seed=42) = Random.Xoshiro(seed)