From 4a78972ff534fc6b68c8c9033661089535b6b2fc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 29 Nov 2024 23:28:35 +0000 Subject: [PATCH] Implement getstepsize() for NoAdaptation samplers --- Project.toml | 2 +- src/mcmc/hmc.jl | 6 ++++++ test/mcmc/hmc.jl | 13 +++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2bcc8bee5..cea6b3655 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.35.3" +version = "0.35.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 5887feb5e..5f1caead2 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -463,6 +463,12 @@ end getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor) +function getstepsize( + sampler::Sampler{<:AdaptiveHamiltonian}, + state::HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation}, +) where {TV,TKernel,THam,PhType} + return state.kernel.τ.integrator.ϵ +end gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim) function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 27c928896..91dce486a 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -329,6 +329,19 @@ using Turing @test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001 end + @testset "getstepsize: Turing.jl#2400" begin + algs = [HMC(0.1, 10), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)] + @testset "$(alg)" for alg in algs + # Construct a HMC state by taking a single step + spl = Sampler(alg, gdemo_default) + hmc_state = DynamicPPL.initialstep( + Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default) + )[2] + # Check that we can obtain the current step size + @test Turing.Inference.getstepsize(spl, hmc_state) isa Float64 + end + end + @testset "Check ADType" begin alg = HMC(0.1, 10; adtype=adbackend) m = DynamicPPL.contextualize(