From f6e49acc311198b3cc9b99161281489976eec6d9 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 10 Feb 2021 19:23:46 +0000 Subject: [PATCH 1/8] unify trajectories --- src/AdvancedHMC.jl | 32 ++++++-- src/trajectory.jl | 180 ++++++++++---------------------------------- test/demo.jl | 2 +- test/sampler-vec.jl | 16 ++-- test/sampler.jl | 21 +++--- 5 files changed, 85 insertions(+), 166 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 55e7673d..99812974 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -34,13 +34,35 @@ include("integrator.jl") export Leapfrog, JitteredLeapfrog, TemperedLeapfrog include("trajectory.jl") -@deprecate find_good_eps find_good_stepsize -export EndPointTS, SliceTS, MultinomialTS, - StaticTrajectory, HMCDA, NUTS, - ClassicNoUTurn, GeneralisedNoUTurn, - StrictGeneralisedNoUTurn, +export Trajectory, + FixedNSteps, FixedIntegrationTime, + ClassicNoUTurn, GeneralisedNoUTurn, StrictGeneralisedNoUTurn, + EndPointTS, SliceTS, MultinomialTS, find_good_stepsize +# Deprecations for trajectory.jl + +struct AbstractTrajectory <: AbstractProposal end + +struct StaticTrajectory{TS} end +@deprecate StaticTrajectory{TS}(int::AbstractIntegrator, L) where {TS} Trajectory{TS}(int, FixedNSteps(L)) +@deprecate StaticTrajectory(int::AbstractIntegrator, L) Trajectory{EndPointTS}(int, FixedNSteps(L)) +@deprecate StaticTrajectory(ϵ::AbstractScalarOrVec{<:Real}, L) Trajectory{EndPointTS}(Leapfrog(ϵ), FixedNSteps(L)) + +struct HMCDA{TS} end +@deprecate HMCDA{TS}(int::AbstractIntegrator, λ) where {TS} Trajectory{TS}(int, FixedIntegrationTime(λ)) +@deprecate HMCDA(int::AbstractIntegrator, λ) Trajectory{MetropolisTS}(int, FixedIntegrationTime(λ)) +@deprecate HMCDA(ϵ::AbstractScalarOrVec{<:Real}, λ) Trajectory{MetropolisTS}(Leapfrog(ϵ), FixedIntegrationTime(λ)) + +struct NUTS{TS, TC} end +@deprecate NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC} Trajectory{TS}(int, TC(args...; kwargs...)) +@deprecate NUTS(int::AbstractIntegrator, args...; kwargs...) Trajectory{MultinomialTS}(int, GeneralisedNoUTurn(args...; kwargs...)) +@deprecate NUTS(ϵ::AbstractScalarOrVec{<:Real}) Trajectory{MultinomialTS}(Leapfrog(ϵ), GeneralisedNoUTurn()) + +@deprecate find_good_eps find_good_stepsize + +export AbstractTrajectory, StaticTrajectory, HMCDA, NUTS, find_good_eps + include("adaptation/Adaptation.jl") using .Adaptation import .Adaptation: StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging diff --git a/src/trajectory.jl b/src/trajectory.jl index d97e9dc0..775b3c46 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -63,9 +63,6 @@ struct FixedIntegrationTime{F<:AbstractFloat} <: StaticTerminationCriterion λ::F end -"Hamiltonian dynamics numerical simulation trajectories." -abstract type AbstractTrajectory{I<:AbstractIntegrator} <: AbstractProposal end - ## ## Sampling methods for trajectories. ## @@ -183,6 +180,25 @@ function mh_accept(rng::AbstractRNG, s::MultinomialTS, s′::MultinomialTS) return rand(rng) < min(1, exp(s′.ℓw - s.ℓw)) end +""" +$(TYPEDEF) + +Numerically simulated Hamiltonian trajectories. +""" +struct Trajectory{TS<:AbstractTrajectorySampler, I<:AbstractIntegrator, TC<:AbstractTerminationCriterion} <: AbstractProposal + "Integrator used to simulate trajectory." + integrator::I + "Criterion to terminate the simulation." + termination_criterion::TC +end + +Trajectory{TS}(integrator::I, termination_criterion::TC) where {TS, I, TC} = + Trajectory{TS, I, TC}(integrator, termination_criterion) + +function Base.show(io::IO, τ::Trajectory{TS}) where {TS} + print(io, "Trajectory{$TS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))") +end + """ $(SIGNATURES) @@ -190,7 +206,7 @@ Make a MCMC transition from phase point `z` using the trajectory `τ` under Hami NOTE: This is a RNG-implicit fallback function for `transition(GLOBAL_RNG, τ, h, z)` """ -function transition(τ::AbstractTrajectory, h::Hamiltonian, z::PhasePoint) +function transition(τ::Trajectory, h::Hamiltonian, z::PhasePoint) return transition(GLOBAL_RNG, τ, h, z) end @@ -198,46 +214,12 @@ end ### Actual trajectory implementations ### -""" -$(TYPEDEF) -Static HMC with a fixed number of leapfrog steps. - -# Fields -$(TYPEDFIELDS) - -# References -1. Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of Markov chain Monte Carlo, 2(11), 2. ([arXiv](https://arxiv.org/pdf/1206.1901)) -""" -struct StaticTrajectory{ - S<:AbstractTrajectorySampler, - I<:AbstractIntegrator, - TC<:StaticTerminationCriterion -} <: AbstractTrajectory{I} - "Integrator used to simulate trajectory." - integrator :: I - termination_criterion :: TC -end - -function Base.show(io::IO, τ::StaticTrajectory{<:EndPointTS}) - print(io, "StaticTrajectory{EndPointTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))") -end - -function Base.show(io::IO, τ::StaticTrajectory{<:MultinomialTS}) - print(io, "StaticTrajectory{MultinomialTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))") -end - -function StaticTrajectory{S}(integrator::I, n_steps::Int) where {S,I} - tc = FixedNSteps(n_steps) - return StaticTrajectory{S, I, typeof(tc)}(integrator, tc) -end -StaticTrajectory(args...) = StaticTrajectory{EndPointTS}(args...) # default StaticTrajectory using last point from trajectory - function transition( rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}}, - τ::StaticTrajectory, + τ::Trajectory{TS, I, TC}, h::Hamiltonian, z::PhasePoint, -) +) where {TS<:AbstractTrajectorySampler, I, TC<:FixedNSteps} H0 = energy(z) integrator = jitter(rng, τ.integrator) @@ -291,7 +273,7 @@ end ### Use end-point from the trajectory as a proposal and apply MH correction -function sample_phasepoint(rng, τ::StaticTrajectory{EndPointTS}, h, z) +function sample_phasepoint(rng, τ::Trajectory{EndPointTS}, h, z) z′ = step(τ.integrator, h, z, τ.termination_criterion.L) is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′)) return z′, is_accept, α @@ -322,7 +304,7 @@ function randcat(rng, zs::AbstractVector{<:PhasePoint}, unnorm_ℓP::AbstractMat return z end -function sample_phasepoint(rng, τ::StaticTrajectory{MultinomialTS}, h, z) +function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z) n_steps = abs(τ.termination_criterion.L) # TODO: Deal with vectorized-mode generically. # Currently the direction of multiple chains are always coupled @@ -345,50 +327,20 @@ function sample_phasepoint(rng, τ::StaticTrajectory{MultinomialTS}, h, z) return z′, true, α end -abstract type DynamicTrajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I} end - ### ### Standard HMC implementation with fixed total trajectory length. ### -""" -$(TYPEDEF) -Standard HMC implementation with fixed total trajectory length. - -# Fields -$(TYPEDFIELDS) - -# References -1. Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of Markov chain Monte Carlo, 2(11), 2. ([arXiv](https://arxiv.org/pdf/1206.1901)) -""" -struct HMCDA{S<:AbstractTrajectorySampler, I<:AbstractIntegrator, TC<:StaticTerminationCriterion} <: DynamicTrajectory{I} - integrator :: I - termination_criterion :: TC -end - -function Base.show(io::IO, τ::HMCDA{<:EndPointTS}) - print(io, "HMCDA{EndPointTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))") -end -function Base.show(io::IO, τ::HMCDA{<:MultinomialTS}) - print(io, "HMCDA{MultinomialTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))") -end - -function HMCDA{S}(integrator::I, λ::AbstractFloat) where {S,I} - tc = FixedIntegrationTime(λ) - return HMCDA{S, I, typeof(tc)}(integrator, tc) -end -HMCDA(args...) = HMCDA{EndPointTS}(args...) # default HMCDA using last point from trajectory - function transition( rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}}, - τ::HMCDA{S}, + τ::Trajectory{TS, I, TC}, h::Hamiltonian, z::PhasePoint, -) where {S} +) where {TS<:AbstractTrajectorySampler, I, TC<:FixedIntegrationTime} # Create the corresponding static τ n_steps = max(1, floor(Int, τ.termination_criterion.λ / nom_step_size(τ.integrator))) - static_τ = StaticTrajectory{S}(τ.integrator, n_steps) - return transition(rng, static_τ, h, z) + τ = Trajectory{TS}(τ.integrator, FixedNSteps(n_steps)) + return transition(rng, τ, h, z) end ### @@ -464,62 +416,6 @@ TurnStatistic(::Union{GeneralisedNoUTurn, StrictGeneralisedNoUTurn}, z::PhasePoi combine(ts::TurnStatistic{T}, ::TurnStatistic{T}) where {T<:UndefInitializer} = ts combine(tsl::T, tsr::T) where {T<:TurnStatistic} = TurnStatistic(tsl.rho + tsr.rho) -## -## NUTS -## - -""" -$(TYPEDEF) -Dynamic HMC algorithm using the no-U-turn termination criteria. - -# Fields -$(TYPEDFIELDS) -""" -struct NUTS{ - S<:AbstractTrajectorySampler, - C<:DynamicTerminationCriterion, - I<:AbstractIntegrator, -} <: DynamicTrajectory{I} - integrator :: I - termination_criterion :: C -end - -function Base.show(io::IO, τ::NUTS{<:SliceTS, <:ClassicNoUTurn}) - print(io, "NUTS{SliceTS}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))") -end -function Base.show(io::IO, τ::NUTS{<:SliceTS, <:GeneralisedNoUTurn}) - print(io, "NUTS{SliceTS,Generalised}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))") -end -function Base.show(io::IO, τ::NUTS{<:MultinomialTS, <:ClassicNoUTurn}) - print(io, "NUTS{MultinomialTS}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))") -end -function Base.show(io::IO, τ::NUTS{<:MultinomialTS, <:GeneralisedNoUTurn}) - print(io, "NUTS{MultinomialTS,Generalised}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))") -end - - -const NUTS_DOCSTR = """ - NUTS{S,C}( - integrator::I, - max_depth::Int=10, - Δ_max::F=1000.0 - ) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion} - -Create an instance for the No-U-Turn sampling algorithm. -""" - -"$(SIGNATURES)" -function NUTS{S,C}( - integrator::I, - max_depth::Int=10, - Δ_max::F=1000.0, -) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion} - return NUTS{S,C,I}(integrator, C(max_depth, Δ_max)) -end - -"$(SIGNATURES)" -NUTS(args...) = NUTS{MultinomialTS, GeneralisedNoUTurn}(args...) - ### ### The doubling tree algorithm for expanding trajectory. ### @@ -541,20 +437,20 @@ Base.:*(d1::Termination, d2::Termination) = Termination(d1.dynamic || d2.dynamic isterminated(d::Termination) = d.dynamic || d.numerical """ - Termination(s::SliceTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat} +$(SIGNATURES) Check termination of a Hamiltonian trajectory. """ -function Termination(s::SliceTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat} +function Termination(s::SliceTS, nt::Trajectory, H0::F, H′::F) where {F<:AbstractFloat} return Termination(false, !(s.ℓu < nt.termination_criterion.Δ_max + -H′)) end """ - Termination(s::MultinomialTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat} +$(SIGNATURES) Check termination of a Hamiltonian trajectory. """ -function Termination(s::MultinomialTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat} +function Termination(s::MultinomialTS, nt::Trajectory, H0::F, H′::F) where {F<:AbstractFloat} return Termination(false, !(-H0 < nt.termination_criterion.Δ_max + -H′)) end @@ -679,21 +575,21 @@ end "Recursivly build a tree for a given depth `j`." function build_tree( rng::AbstractRNG, - nt::NUTS{S,C,I}, + nt::Trajectory{TS, I, TC}, h::Hamiltonian, z::PhasePoint, sampler::AbstractTrajectorySampler, v::Int, j::Int, H0::AbstractFloat, -) where {I<:AbstractIntegrator,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion} +) where {TS<:AbstractTrajectorySampler, I<:AbstractIntegrator, TC<:DynamicTerminationCriterion} if j == 0 # Base case - take one leapfrog step in the direction v. z′ = step(nt.integrator, h, z, v) H′ = energy(z′) ΔH = H′ - H0 α′ = exp(min(0, -ΔH)) - sampler′ = S(sampler, H0, z′) + sampler′ = TS(sampler, H0, z′) return BinaryTree(z′, z′, TurnStatistic(nt.termination_criterion, z′), α′, 1, ΔH), sampler′, Termination(sampler′, nt, H0, H′) else # Recursion - build the left and right subtrees. @@ -719,13 +615,13 @@ end function transition( rng::AbstractRNG, - τ::NUTS{S,C,I}, + τ::Trajectory{TS, I, TC}, h::Hamiltonian, z0::PhasePoint, -) where {I<:AbstractIntegrator,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion} +) where {TS<:AbstractTrajectorySampler, I<:AbstractIntegrator, TC<:DynamicTerminationCriterion} H0 = energy(z0) tree = BinaryTree(z0, z0, TurnStatistic(τ.termination_criterion, z0), zero(H0), zero(Int), zero(H0)) - sampler = S(rng, z0) + sampler = TS(rng, z0) termination = Termination(false, false) zcand = z0 diff --git a/test/demo.jl b/test/demo.jl index b9ffdedc..546313f9 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -21,7 +21,7 @@ integrator = Leapfrog(initial_ϵ) # - multinomial sampling scheme, # - generalised No-U-Turn criteria, and # - windowed adaption for step-size and diagonal mass matrix -proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) +proposal = Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()) adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)) # Run the sampler to draw samples from the specified Gaussian, where diff --git a/test/sampler-vec.jl b/test/sampler-vec.jl index 20fa5cfd..d75624d2 100644 --- a/test/sampler-vec.jl +++ b/test/sampler-vec.jl @@ -20,12 +20,12 @@ include("common.jl") DiagEuclideanMetric, # DenseEuclideanMetric # not supported at the moment ], τ in [ - StaticTrajectory{EndPointTS}(lfi, n_steps), - StaticTrajectory{MultinomialTS}(lfi, n_steps), - StaticTrajectory{EndPointTS}(lfi_jittered, n_steps), - StaticTrajectory{MultinomialTS}(lfi_jittered, n_steps), - HMCDA{EndPointTS}(lf, ϵ * n_steps), - HMCDA{MultinomialTS}(lf, ϵ * n_steps), + Trajectory{EndPointTS}(lfi, FixedNSteps(n_steps)), + Trajectory{MultinomialTS}(lfi, FixedNSteps(n_steps)), + Trajectory{EndPointTS}(lfi_jittered, FixedNSteps(n_steps)), + Trajectory{MultinomialTS}(lfi_jittered, FixedNSteps(n_steps)), + Trajectory{EndPointTS}(lf, FixedIntegrationTime(ϵ * n_steps)), + Trajectory{MultinomialTS}(lf, FixedIntegrationTime(ϵ * n_steps)), ] n_chains = n_chains_list[i_test] metric = metricT((D, n_chains)) @@ -52,7 +52,7 @@ include("common.jl") StepSizeAdaptor(0.8, lfi), ), ] - τ isa HMCDA && continue + τ.termination_criterion isa FixedIntegrationTime && continue @test show(adaptor) == nothing Random.seed!(100) @@ -73,7 +73,7 @@ include("common.jl") end @test all_same end - @info "Adaptation tests for HMCDA with StepSizeAdaptor are skipped" + @info "Adaptation tests for FixedIntegrationTime with StepSizeAdaptor are skipped" # Simple time benchmark let metricT=UnitEuclideanMetric diff --git a/test/sampler.jl b/test/sampler.jl index 7a1b613b..de776a7b 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -2,6 +2,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using Test, AdvancedHMC, LinearAlgebra, Random, MCMCDebugging, Plots +using AdvancedHMC: StaticTerminationCriterion, DynamicTerminationCriterion using Parameters: reconstruct using Statistics: mean, var, cov unicodeplots() @@ -13,7 +14,7 @@ n_steps = 10 n_samples = 22_000 n_adapts = 4_000 -function test_stats(::Union{StaticTrajectory,HMCDA}, stats, n_adapts) +function test_stats(::Trajectory{TS,I,TC}, stats, n_adapts) where {TS,I,TC<:StaticTerminationCriterion} for name in (:step_size, :nom_step_size, :n_steps, :is_accept, :acceptance_rate, :log_density, :hamiltonian_energy, :hamiltonian_energy_error, :is_adapt) @test all(map(s -> in(name, propertynames(s)), stats)) end @@ -22,7 +23,7 @@ function test_stats(::Union{StaticTrajectory,HMCDA}, stats, n_adapts) @test is_adapts[(n_adapts+1):end] == zeros(Bool, length(stats) - n_adapts) end -function test_stats(::NUTS, stats, n_adapts) +function test_stats(::Trajectory{TS,I,TC}, stats, n_adapts) where {TS,I,TC<:DynamicTerminationCriterion} for name in (:step_size, :nom_step_size, :n_steps, :is_accept, :acceptance_rate, :log_density, :hamiltonian_energy, :hamiltonian_energy_error, :is_adapt, :max_hamiltonian_energy_error, :tree_depth, :numerical_error) @test all(map(s -> in(name, propertynames(s)), stats)) end @@ -45,14 +46,14 @@ end :TemperedLeapfrog => TemperedLeapfrog(ϵ, 1.05), ) @testset "$τsym" for (τsym, τ) in Dict( - :(StaticTrajectory{EndPointTS}) => StaticTrajectory{EndPointTS}(lf, n_steps), - :(StaticTrajectory{MultinomialTS}) => StaticTrajectory{MultinomialTS}(lf, n_steps), - :(HMCDA{EndPointTS}) => HMCDA{EndPointTS}(lf, ϵ * n_steps), - :(HMCDA{MultinomialTS}) => HMCDA{MultinomialTS}(lf, ϵ * n_steps), - :(NUTS{SliceTS,Original}) => NUTS{SliceTS,ClassicNoUTurn}(lf), - :(NUTS{SliceTS,Generalised}) => NUTS{SliceTS,GeneralisedNoUTurn}(lf), - :(NUTS{MultinomialTS,Original}) => NUTS{MultinomialTS,ClassicNoUTurn}(lf), - :(NUTS{MultinomialTS,Generalised}) => NUTS{MultinomialTS,GeneralisedNoUTurn}(lf), + :(HMC{EndPointTS}) => Trajectory{EndPointTS}(lf, FixedNSteps(n_steps)), + :(HMC{MultinomialTS}) => Trajectory{MultinomialTS}(lf, FixedNSteps(n_steps)), + :(HMCDA{EndPointTS}) => Trajectory{EndPointTS}(lf, FixedIntegrationTime(ϵ * n_steps)), + :(HMCDA{MultinomialTS}) => Trajectory{MultinomialTS}(lf, FixedIntegrationTime(ϵ * n_steps)), + :(NUTS{SliceTS,Original}) => Trajectory{SliceTS}(lf, ClassicNoUTurn()), + :(NUTS{SliceTS,Generalised}) => Trajectory{SliceTS}(lf, GeneralisedNoUTurn()), + :(NUTS{MultinomialTS,Original}) => Trajectory{MultinomialTS}(lf, ClassicNoUTurn()), + :(NUTS{MultinomialTS,Generalised}) => Trajectory{MultinomialTS}(lf, GeneralisedNoUTurn()), ) @test show(h) == nothing @test show(τ) == nothing From fe71eb180b2c79a8bc79766d99b198dad64c5b87 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 10 Feb 2021 19:47:57 +0000 Subject: [PATCH 2/8] introduce nsteps suggested by Tor --- src/trajectory.jl | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/trajectory.jl b/src/trajectory.jl index 775b3c46..68d1e795 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -199,6 +199,10 @@ function Base.show(io::IO, τ::Trajectory{TS}) where {TS} print(io, "Trajectory{$TS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))") end +nsteps(τ::Trajectory{TS, I, TC}) where {TS, I, TC<:FixedNSteps} = τ.termination_criterion.L +nsteps(τ::Trajectory{TS, I, TC}) where {TS, I, TC<:FixedIntegrationTime} = + max(1, floor(Int, τ.termination_criterion.λ / nom_step_size(τ.integrator))) + """ $(SIGNATURES) @@ -219,7 +223,7 @@ function transition( τ::Trajectory{TS, I, TC}, h::Hamiltonian, z::PhasePoint, -) where {TS<:AbstractTrajectorySampler, I, TC<:FixedNSteps} +) where {TS<:AbstractTrajectorySampler, I, TC<:StaticTerminationCriterion} H0 = energy(z) integrator = jitter(rng, τ.integrator) @@ -233,7 +237,7 @@ function transition( H = energy(z) tstat = merge( ( - n_steps=τ.termination_criterion.L, + n_steps=nsteps(τ), is_accept=is_accept, acceptance_rate=α, log_density=z.ℓπ.value, @@ -274,7 +278,7 @@ end ### Use end-point from the trajectory as a proposal and apply MH correction function sample_phasepoint(rng, τ::Trajectory{EndPointTS}, h, z) - z′ = step(τ.integrator, h, z, τ.termination_criterion.L) + z′ = step(τ.integrator, h, z, nsteps(τ)) is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′)) return z′, is_accept, α end @@ -305,7 +309,7 @@ function randcat(rng, zs::AbstractVector{<:PhasePoint}, unnorm_ℓP::AbstractMat end function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z) - n_steps = abs(τ.termination_criterion.L) + n_steps = abs(nsteps(τ)) # TODO: Deal with vectorized-mode generically. # Currently the direction of multiple chains are always coupled n_steps_fwd = rand_coupled(rng, 0:n_steps) @@ -327,22 +331,6 @@ function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z) return z′, true, α end -### -### Standard HMC implementation with fixed total trajectory length. -### - -function transition( - rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}}, - τ::Trajectory{TS, I, TC}, - h::Hamiltonian, - z::PhasePoint, -) where {TS<:AbstractTrajectorySampler, I, TC<:FixedIntegrationTime} - # Create the corresponding static τ - n_steps = max(1, floor(Int, τ.termination_criterion.λ / nom_step_size(τ.integrator))) - τ = Trajectory{TS}(τ.integrator, FixedNSteps(n_steps)) - return transition(rng, τ, h, z) -end - ### ### Advanced HMC implementation with (adaptive) dynamic trajectory length. ### From 5747c48e7f30e62da1490d394ffc44c1a3cea333 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 10 Feb 2021 19:49:35 +0000 Subject: [PATCH 3/8] rename tests --- test/sampler.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/sampler.jl b/test/sampler.jl index de776a7b..29df7d1a 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -46,14 +46,14 @@ end :TemperedLeapfrog => TemperedLeapfrog(ϵ, 1.05), ) @testset "$τsym" for (τsym, τ) in Dict( - :(HMC{EndPointTS}) => Trajectory{EndPointTS}(lf, FixedNSteps(n_steps)), - :(HMC{MultinomialTS}) => Trajectory{MultinomialTS}(lf, FixedNSteps(n_steps)), - :(HMCDA{EndPointTS}) => Trajectory{EndPointTS}(lf, FixedIntegrationTime(ϵ * n_steps)), - :(HMCDA{MultinomialTS}) => Trajectory{MultinomialTS}(lf, FixedIntegrationTime(ϵ * n_steps)), - :(NUTS{SliceTS,Original}) => Trajectory{SliceTS}(lf, ClassicNoUTurn()), - :(NUTS{SliceTS,Generalised}) => Trajectory{SliceTS}(lf, GeneralisedNoUTurn()), - :(NUTS{MultinomialTS,Original}) => Trajectory{MultinomialTS}(lf, ClassicNoUTurn()), - :(NUTS{MultinomialTS,Generalised}) => Trajectory{MultinomialTS}(lf, GeneralisedNoUTurn()), + :(Trajectory{EndPointTS,FixedNSteps}) => Trajectory{EndPointTS}(lf, FixedNSteps(n_steps)), + :(Trajectory{MultinomialTS,FixedNSteps}) => Trajectory{MultinomialTS}(lf, FixedNSteps(n_steps)), + :(Trajectory{EndPointTS,FixedIntegrationTime}) => Trajectory{EndPointTS}(lf, FixedIntegrationTime(ϵ * n_steps)), + :(Trajectory{MultinomialTS,FixedIntegrationTime}) => Trajectory{MultinomialTS}(lf, FixedIntegrationTime(ϵ * n_steps)), + :(Trajectory{SliceTS,Original}) => Trajectory{SliceTS}(lf, ClassicNoUTurn()), + :(Trajectory{SliceTS,Generalised}) => Trajectory{SliceTS}(lf, GeneralisedNoUTurn()), + :(Trajectory{MultinomialTS,Original}) => Trajectory{MultinomialTS}(lf, ClassicNoUTurn()), + :(Trajectory{MultinomialTS,Generalised}) => Trajectory{MultinomialTS}(lf, GeneralisedNoUTurn()), ) @test show(h) == nothing @test show(τ) == nothing From 740aa93b0a741d0232c78cf40636e2974a61087e Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 10 Feb 2021 19:58:47 +0000 Subject: [PATCH 4/8] update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fefcdd5c..5b8f34ab 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ integrator = Leapfrog(initial_ϵ) # - multinomial sampling scheme, # - generalised No-U-Turn criteria, and # - windowed adaption for step-size and diagonal mass matrix -proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) +proposal = Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()) adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)) # Run the sampler to draw samples from the specified Gaussian, where From 0b6dc98fa00715a760b4faa1dfd049eecb58cb55 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 12 Feb 2021 21:15:27 +0000 Subject: [PATCH 5/8] support NUTS instead of deprecating it --- src/AdvancedHMC.jl | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 99812974..e3196a7c 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -40,6 +40,29 @@ export Trajectory, EndPointTS, SliceTS, MultinomialTS, find_good_stepsize +# Useful defaults + +""" +$(TYPEDEF) + +Convenient constructor for the no-U-turn sampler (NUTS). +This falls back to `Trajectory{TS}(int, TC(args...; kwargs...))` where + +- `TS<:Union{MultinomialTS, SliceTS}` is the type for trajectory sampler +- `TC<:Union{ClassicNoUTurn, GeneralisedNoUTurn, StrictGeneralisedNoUTurn}` is the type for termination criterion. + +See [`ClassicNoUTurn`](@ref), [`GeneralisedNoUTurn`](@ref) and [`StrictGeneralisedNoUTurn`](@ref) for details in parameters. +""" +struct NUTS{TS, TC} end +NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC} = + Trajectory{TS}(int, TC(args...; kwargs...)) +NUTS(int::AbstractIntegrator, args...; kwargs...) = + Trajectory{MultinomialTS}(int, GeneralisedNoUTurn(args...; kwargs...)) +NUTS(ϵ::AbstractScalarOrVec{<:Real}) = + Trajectory{MultinomialTS}(Leapfrog(ϵ), GeneralisedNoUTurn()) + +export NUTS + # Deprecations for trajectory.jl struct AbstractTrajectory <: AbstractProposal end @@ -54,14 +77,9 @@ struct HMCDA{TS} end @deprecate HMCDA(int::AbstractIntegrator, λ) Trajectory{MetropolisTS}(int, FixedIntegrationTime(λ)) @deprecate HMCDA(ϵ::AbstractScalarOrVec{<:Real}, λ) Trajectory{MetropolisTS}(Leapfrog(ϵ), FixedIntegrationTime(λ)) -struct NUTS{TS, TC} end -@deprecate NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC} Trajectory{TS}(int, TC(args...; kwargs...)) -@deprecate NUTS(int::AbstractIntegrator, args...; kwargs...) Trajectory{MultinomialTS}(int, GeneralisedNoUTurn(args...; kwargs...)) -@deprecate NUTS(ϵ::AbstractScalarOrVec{<:Real}) Trajectory{MultinomialTS}(Leapfrog(ϵ), GeneralisedNoUTurn()) - @deprecate find_good_eps find_good_stepsize -export AbstractTrajectory, StaticTrajectory, HMCDA, NUTS, find_good_eps +export AbstractTrajectory, StaticTrajectory, HMCDA, find_good_eps include("adaptation/Adaptation.jl") using .Adaptation From c8803cb4957b7f85dedb3f350b9b2bd21f38e512 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 12 Feb 2021 21:18:17 +0000 Subject: [PATCH 6/8] move comment to function instead of type def --- src/AdvancedHMC.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index e3196a7c..7473afeb 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -42,8 +42,10 @@ export Trajectory, # Useful defaults +struct NUTS{TS, TC} end + """ -$(TYPEDEF) +$(SIGNATURES) Convenient constructor for the no-U-turn sampler (NUTS). This falls back to `Trajectory{TS}(int, TC(args...; kwargs...))` where @@ -53,7 +55,6 @@ This falls back to `Trajectory{TS}(int, TC(args...; kwargs...))` where See [`ClassicNoUTurn`](@ref), [`GeneralisedNoUTurn`](@ref) and [`StrictGeneralisedNoUTurn`](@ref) for details in parameters. """ -struct NUTS{TS, TC} end NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC} = Trajectory{TS}(int, TC(args...; kwargs...)) NUTS(int::AbstractIntegrator, args...; kwargs...) = From 892ca468f01d85ca53973abe741993e8242c83b3 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sun, 14 Feb 2021 22:58:59 +0000 Subject: [PATCH 7/8] use the old NUTS interface in demo --- README.md | 2 +- test/demo.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5b8f34ab..fefcdd5c 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ integrator = Leapfrog(initial_ϵ) # - multinomial sampling scheme, # - generalised No-U-Turn criteria, and # - windowed adaption for step-size and diagonal mass matrix -proposal = Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()) +proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)) # Run the sampler to draw samples from the specified Gaussian, where diff --git a/test/demo.jl b/test/demo.jl index 546313f9..b9ffdedc 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -21,7 +21,7 @@ integrator = Leapfrog(initial_ϵ) # - multinomial sampling scheme, # - generalised No-U-Turn criteria, and # - windowed adaption for step-size and diagonal mass matrix -proposal = Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()) +proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)) # Run the sampler to draw samples from the specified Gaussian, where From bc2e6d9459ea56521d1edda7b2477f6b5e404a2c Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sun, 14 Feb 2021 22:59:31 +0000 Subject: [PATCH 8/8] remove unnecessary exports --- src/AdvancedHMC.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 7473afeb..144224df 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -80,7 +80,7 @@ struct HMCDA{TS} end @deprecate find_good_eps find_good_stepsize -export AbstractTrajectory, StaticTrajectory, HMCDA, find_good_eps +export StaticTrajectory, HMCDA, find_good_eps include("adaptation/Adaptation.jl") using .Adaptation