Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unifying trajectories #245

Merged
merged 8 commits into from
Feb 15, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ integrator = Leapfrog(initial_ϵ)
# - multinomial sampling scheme,
# - generalised No-U-Turn criterion, and
# - windowed adaptation for step-size and diagonal mass matrix
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
proposal = Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())
yebai marked this conversation as resolved.
Show resolved Hide resolved
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

# Run the sampler to draw samples from the specified Gaussian, where
Expand Down
51 changes: 46 additions & 5 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,54 @@ include("integrator.jl")
export Leapfrog, JitteredLeapfrog, TemperedLeapfrog

include("trajectory.jl")
@deprecate find_good_eps find_good_stepsize
export EndPointTS, SliceTS, MultinomialTS,
StaticTrajectory, HMCDA, NUTS,
ClassicNoUTurn, GeneralisedNoUTurn,
StrictGeneralisedNoUTurn,
export Trajectory,
FixedNSteps, FixedIntegrationTime,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consider unifying these type names in a future PR, e.g. FixedNSteps ==> FixedIntegrationSteps.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created an issue for this and some other potential naming improvements #246.

ClassicNoUTurn, GeneralisedNoUTurn, StrictGeneralisedNoUTurn,
EndPointTS, SliceTS, MultinomialTS,
find_good_stepsize

# Useful defaults

# `NUTS` only serves as a constructor name: every method below returns a
# `Trajectory`, never a `NUTS` instance.
struct NUTS{TS, TC} end

"""
$(SIGNATURES)

Convenience constructor for the no-U-turn sampler (NUTS).
Returns `Trajectory{TS}(int, TC(args...; kwargs...))`, where

- `TS` is the trajectory sampler type (e.g. `MultinomialTS` or `SliceTS`), and
- `TC` is the termination criterion type (e.g. `ClassicNoUTurn`,
  `GeneralisedNoUTurn` or `StrictGeneralisedNoUTurn`).

All `args` and `kwargs` are forwarded to the constructor of `TC`; see
[`ClassicNoUTurn`](@ref), [`GeneralisedNoUTurn`](@ref) and
[`StrictGeneralisedNoUTurn`](@ref) for the parameters they accept.
"""
function NUTS{TS, TC}(int::AbstractIntegrator, args...; kwargs...) where {TS, TC}
    criterion = TC(args...; kwargs...)
    return Trajectory{TS}(int, criterion)
end

# Without type parameters: multinomial sampling with the generalised no-U-turn criterion.
function NUTS(int::AbstractIntegrator, args...; kwargs...)
    criterion = GeneralisedNoUTurn(args...; kwargs...)
    return Trajectory{MultinomialTS}(int, criterion)
end

# From a bare step size: same defaults, with the step size wrapped in a `Leapfrog`.
function NUTS(ϵ::AbstractScalarOrVec{<:Real})
    return Trajectory{MultinomialTS}(Leapfrog(ϵ), GeneralisedNoUTurn())
end

export NUTS

# Deprecations for trajectory.jl

# Placeholder type, presumably kept so that the old `AbstractTrajectory` name
# still resolves now that the abstract type was removed from trajectory.jl.
struct AbstractTrajectory <: AbstractProposal end

# `StaticTrajectory` is only a constructor name now: each deprecated method
# forwards to a `Trajectory` that terminates after a fixed number of steps `L`.
struct StaticTrajectory{TS} end
# Explicit trajectory sampler `TS`.
@deprecate StaticTrajectory{TS}(int::AbstractIntegrator, L) where {TS} Trajectory{TS}(int, FixedNSteps(L))
# No sampler given: use the end point of the trajectory (`EndPointTS`).
@deprecate StaticTrajectory(int::AbstractIntegrator, L) Trajectory{EndPointTS}(int, FixedNSteps(L))
# Step size given instead of an integrator: wrap it in a `Leapfrog`.
@deprecate StaticTrajectory(ϵ::AbstractScalarOrVec{<:Real}, L) Trajectory{EndPointTS}(Leapfrog(ϵ), FixedNSteps(L))

# `HMCDA` is only a constructor name now: each deprecated method forwards to a
# `Trajectory` that terminates after a fixed integration time `λ`.
struct HMCDA{TS} end
# Explicit trajectory sampler `TS`.
@deprecate HMCDA{TS}(int::AbstractIntegrator, λ) where {TS} Trajectory{TS}(int, FixedIntegrationTime(λ))
# No sampler given: use `EndPointTS`, which is what the pre-refactor default
# `HMCDA(args...) = HMCDA{EndPointTS}(args...)` selected and what the
# `StaticTrajectory` deprecations use as well.
@deprecate HMCDA(int::AbstractIntegrator, λ) Trajectory{EndPointTS}(int, FixedIntegrationTime(λ))
# Step size given instead of an integrator: wrap it in a `Leapfrog`.
@deprecate HMCDA(ϵ::AbstractScalarOrVec{<:Real}, λ) Trajectory{EndPointTS}(Leapfrog(ϵ), FixedIntegrationTime(λ))

# Old name of `find_good_stepsize`, kept as a deprecated alias.
@deprecate find_good_eps find_good_stepsize

# Export the deprecated names so that code relying on them being exported
# still resolves them.
export AbstractTrajectory, StaticTrajectory, HMCDA, find_good_eps
yebai marked this conversation as resolved.
Show resolved Hide resolved

include("adaptation/Adaptation.jl")
using .Adaptation
import .Adaptation: StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging
Expand Down
198 changes: 41 additions & 157 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ struct FixedIntegrationTime{F<:AbstractFloat} <: StaticTerminationCriterion
λ::F
end

"Hamiltonian dynamics numerical simulation trajectories."
abstract type AbstractTrajectory{I<:AbstractIntegrator} <: AbstractProposal end

##
## Sampling methods for trajectories.
##
Expand Down Expand Up @@ -184,60 +181,49 @@ function mh_accept(rng::AbstractRNG, s::MultinomialTS, s′::MultinomialTS)
end

"""
$(SIGNATURES)

Make a MCMC transition from phase point `z` using the trajectory `τ` under Hamiltonian `h`.
$(TYPEDEF)

NOTE: This is a RNG-implicit fallback function for `transition(GLOBAL_RNG, τ, h, z)`
Numerically simulated Hamiltonian trajectories.
"""
function transition(τ::AbstractTrajectory, h::Hamiltonian, z::PhasePoint)
return transition(GLOBAL_RNG, τ, h, z)
struct Trajectory{
    TS<:AbstractTrajectorySampler,
    I<:AbstractIntegrator,
    TC<:AbstractTerminationCriterion
} <: AbstractProposal
    # `TS` types no field; it only tags the trajectory with its sampler type
    # so that methods can dispatch on it.
    "Integrator that advances the phase point along the trajectory."
    integrator::I
    "Criterion deciding when the simulation is terminated."
    termination_criterion::TC
end

###
### Actual trajectory implementations
###
# Outer constructor: only the sampler type `TS` is spelled out by the caller;
# the remaining type parameters are taken from the argument types.
function Trajectory{TS}(integrator::I, termination_criterion::TC) where {TS, I, TC}
    return Trajectory{TS, I, TC}(integrator, termination_criterion)
end

"""
$(TYPEDEF)
Static HMC with a fixed number of leapfrog steps.
# Compact one-line display: sampler type, integrator and termination criterion.
function Base.show(io::IO, τ::Trajectory{TS}) where {TS}
    description = "Trajectory{$TS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))"
    print(io, description)
    return nothing
end

# Fields
$(TYPEDFIELDS)
# Number of integration steps for a fixed-step criterion: the stored `L`.
function nsteps(τ::Trajectory{<:Any, <:Any, <:FixedNSteps})
    return τ.termination_criterion.L
end

# Number of integration steps for a fixed-time criterion: the integration time
# `λ` divided by the integrator's nominal step size, rounded down, at least 1.
function nsteps(τ::Trajectory{<:Any, <:Any, <:FixedIntegrationTime})
    λ = τ.termination_criterion.λ
    ϵ = nom_step_size(τ.integrator)
    return max(1, floor(Int, λ / ϵ))
end

# References
1. Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of Markov chain Monte Carlo, 2(11), 2. ([arXiv](https://arxiv.org/pdf/1206.1901))
"""
struct StaticTrajectory{
S<:AbstractTrajectorySampler,
I<:AbstractIntegrator,
TC<:StaticTerminationCriterion
} <: AbstractTrajectory{I}
"Integrator used to simulate trajectory."
integrator :: I
termination_criterion :: TC
end
$(SIGNATURES)

function Base.show(io::IO, τ::StaticTrajectory{<:EndPointTS})
print(io, "StaticTrajectory{EndPointTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))")
end
Make a MCMC transition from phase point `z` using the trajectory `τ` under Hamiltonian `h`.

function Base.show(io::IO, τ::StaticTrajectory{<:MultinomialTS})
print(io, "StaticTrajectory{MultinomialTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))")
NOTE: This is a RNG-implicit fallback function for `transition(GLOBAL_RNG, τ, h, z)`
"""
transition(τ::Trajectory, h::Hamiltonian, z::PhasePoint) =
    transition(GLOBAL_RNG, τ, h, z)

function StaticTrajectory{S}(integrator::I, n_steps::Int) where {S,I}
tc = FixedNSteps(n_steps)
return StaticTrajectory{S, I, typeof(tc)}(integrator, tc)
end
StaticTrajectory(args...) = StaticTrajectory{EndPointTS}(args...) # default StaticTrajectory using last point from trajectory
###
### Actual trajectory implementations
###

function transition(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
τ::StaticTrajectory,
τ::Trajectory{TS, I, TC},
h::Hamiltonian,
z::PhasePoint,
)
) where {TS<:AbstractTrajectorySampler, I, TC<:StaticTerminationCriterion}
H0 = energy(z)

integrator = jitter(rng, τ.integrator)
Expand All @@ -251,7 +237,7 @@ function transition(
H = energy(z)
tstat = merge(
(
n_steps=τ.termination_criterion.L,
n_steps=nsteps(τ),
is_accept=is_accept,
acceptance_rate=α,
log_density=z.ℓπ.value,
Expand Down Expand Up @@ -291,8 +277,8 @@ end

### Use end-point from the trajectory as a proposal and apply MH correction

function sample_phasepoint(rng, τ::StaticTrajectory{EndPointTS}, h, z)
z′ = step(τ.integrator, h, z, τ.termination_criterion.L)
function sample_phasepoint(rng, τ::Trajectory{EndPointTS}, h, z)
    # Simulate the whole trajectory and take its end point as the proposal.
    z_end = step(τ.integrator, h, z, nsteps(τ))
    # Accept/reject decision and acceptance ratio from the start and end energies.
    H_start = energy(z)
    H_end = energy(z_end)
    accepted, ratio = mh_accept_ratio(rng, H_start, H_end)
    return z_end, accepted, ratio
end
Expand Down Expand Up @@ -322,8 +308,8 @@ function randcat(rng, zs::AbstractVector{<:PhasePoint}, unnorm_ℓP::AbstractMat
return z
end

function sample_phasepoint(rng, τ::StaticTrajectory{MultinomialTS}, h, z)
n_steps = abs(τ.termination_criterion.L)
function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z)
n_steps = abs(nsteps(τ))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Deal with vectorized-mode generically.
# Currently the direction of multiple chains are always coupled
n_steps_fwd = rand_coupled(rng, 0:n_steps)
Expand All @@ -345,52 +331,6 @@ function sample_phasepoint(rng, τ::StaticTrajectory{MultinomialTS}, h, z)
return z′, true, α
end

abstract type DynamicTrajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I} end

###
### Standard HMC implementation with fixed total trajectory length.
###

"""
$(TYPEDEF)
Standard HMC implementation with fixed total trajectory length.

# Fields
$(TYPEDFIELDS)

# References
1. Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of Markov chain Monte Carlo, 2(11), 2. ([arXiv](https://arxiv.org/pdf/1206.1901))
"""
struct HMCDA{S<:AbstractTrajectorySampler, I<:AbstractIntegrator, TC<:StaticTerminationCriterion} <: DynamicTrajectory{I}
integrator :: I
termination_criterion :: TC
end

function Base.show(io::IO, τ::HMCDA{<:EndPointTS})
print(io, "HMCDA{EndPointTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))")
end
function Base.show(io::IO, τ::HMCDA{<:MultinomialTS})
print(io, "HMCDA{MultinomialTS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))")
end

function HMCDA{S}(integrator::I, λ::AbstractFloat) where {S,I}
tc = FixedIntegrationTime(λ)
return HMCDA{S, I, typeof(tc)}(integrator, tc)
end
HMCDA(args...) = HMCDA{EndPointTS}(args...) # default HMCDA using last point from trajectory

function transition(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
τ::HMCDA{S},
h::Hamiltonian,
z::PhasePoint,
) where {S}
# Create the corresponding static τ
n_steps = max(1, floor(Int, τ.termination_criterion.λ / nom_step_size(τ.integrator)))
static_τ = StaticTrajectory{S}(τ.integrator, n_steps)
return transition(rng, static_τ, h, z)
end

###
### Advanced HMC implementation with (adaptive) dynamic trajectory length.
###
Expand Down Expand Up @@ -464,62 +404,6 @@ TurnStatistic(::Union{GeneralisedNoUTurn, StrictGeneralisedNoUTurn}, z::PhasePoi
combine(ts::TurnStatistic{T}, ::TurnStatistic{T}) where {T<:UndefInitializer} = ts
combine(tsl::T, tsr::T) where {T<:TurnStatistic} = TurnStatistic(tsl.rho + tsr.rho)

##
## NUTS
##

"""
$(TYPEDEF)
Dynamic HMC algorithm using the no-U-turn termination criteria.

# Fields
$(TYPEDFIELDS)
"""
struct NUTS{
S<:AbstractTrajectorySampler,
C<:DynamicTerminationCriterion,
I<:AbstractIntegrator,
} <: DynamicTrajectory{I}
integrator :: I
termination_criterion :: C
end

function Base.show(io::IO, τ::NUTS{<:SliceTS, <:ClassicNoUTurn})
print(io, "NUTS{SliceTS}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))")
end
function Base.show(io::IO, τ::NUTS{<:SliceTS, <:GeneralisedNoUTurn})
print(io, "NUTS{SliceTS,Generalised}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))")
end
function Base.show(io::IO, τ::NUTS{<:MultinomialTS, <:ClassicNoUTurn})
print(io, "NUTS{MultinomialTS}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))")
end
function Base.show(io::IO, τ::NUTS{<:MultinomialTS, <:GeneralisedNoUTurn})
print(io, "NUTS{MultinomialTS,Generalised}(integrator=$(τ.integrator), termination_criterion=$(τ.termination_criterion))")
end


const NUTS_DOCSTR = """
NUTS{S,C}(
integrator::I,
max_depth::Int=10,
Δ_max::F=1000.0
) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion}

Create an instance for the No-U-Turn sampling algorithm.
"""

"$(SIGNATURES)"
function NUTS{S,C}(
integrator::I,
max_depth::Int=10,
Δ_max::F=1000.0,
) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion}
return NUTS{S,C,I}(integrator, C(max_depth, Δ_max))
end

"$(SIGNATURES)"
NUTS(args...) = NUTS{MultinomialTS, GeneralisedNoUTurn}(args...)

###
### The doubling tree algorithm for expanding trajectory.
###
Expand All @@ -541,20 +425,20 @@ Base.:*(d1::Termination, d2::Termination) = Termination(d1.dynamic || d2.dynamic
isterminated(d::Termination) = d.dynamic || d.numerical

"""
Termination(s::SliceTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat}
$(SIGNATURES)

Check termination of a Hamiltonian trajectory.
"""
function Termination(s::SliceTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat}
function Termination(s::SliceTS, nt::Trajectory, H0::F, H′::F) where {F<:AbstractFloat}
    # Compare the slice variable against the energy bound `Δ_max - H′`.
    # The test stays a negated `<`, so a NaN energy also counts as termination.
    within_bound = s.ℓu < nt.termination_criterion.Δ_max + -H′
    # Only the second flag (presumably `numerical`) can be raised here.
    return Termination(false, !within_bound)
end

"""
Termination(s::MultinomialTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat}
$(SIGNATURES)

Check termination of a Hamiltonian trajectory.
"""
function Termination(s::MultinomialTS, nt::NUTS, H0::F, H′::F) where {F<:AbstractFloat}
function Termination(s::MultinomialTS, nt::Trajectory, H0::F, H′::F) where {F<:AbstractFloat}
    # Compare the initial energy against the bound `Δ_max - H′`.
    # The test stays a negated `<`, so a NaN energy also counts as termination.
    within_bound = -H0 < nt.termination_criterion.Δ_max + -H′
    # Only the second flag (presumably `numerical`) can be raised here.
    return Termination(false, !within_bound)
end

Expand Down Expand Up @@ -679,21 +563,21 @@ end
"Recursively build a tree for a given depth `j`."
function build_tree(
rng::AbstractRNG,
nt::NUTS{S,C,I},
nt::Trajectory{TS, I, TC},
h::Hamiltonian,
z::PhasePoint,
sampler::AbstractTrajectorySampler,
v::Int,
j::Int,
H0::AbstractFloat,
) where {I<:AbstractIntegrator,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion}
) where {TS<:AbstractTrajectorySampler, I<:AbstractIntegrator, TC<:DynamicTerminationCriterion}
if j == 0
# Base case - take one leapfrog step in the direction v.
z′ = step(nt.integrator, h, z, v)
H′ = energy(z′)
ΔH = H′ - H0
α′ = exp(min(0, -ΔH))
sampler′ = S(sampler, H0, z′)
sampler′ = TS(sampler, H0, z′)
return BinaryTree(z′, z′, TurnStatistic(nt.termination_criterion, z′), α′, 1, ΔH), sampler′, Termination(sampler′, nt, H0, H′)
else
# Recursion - build the left and right subtrees.
Expand All @@ -719,13 +603,13 @@ end

function transition(
rng::AbstractRNG,
τ::NUTS{S,C,I},
τ::Trajectory{TS, I, TC},
h::Hamiltonian,
z0::PhasePoint,
) where {I<:AbstractIntegrator,S<:AbstractTrajectorySampler,C<:DynamicTerminationCriterion}
) where {TS<:AbstractTrajectorySampler, I<:AbstractIntegrator, TC<:DynamicTerminationCriterion}
H0 = energy(z0)
tree = BinaryTree(z0, z0, TurnStatistic(τ.termination_criterion, z0), zero(H0), zero(Int), zero(H0))
sampler = S(rng, z0)
sampler = TS(rng, z0)
termination = Termination(false, false)
zcand = z0

Expand Down
2 changes: 1 addition & 1 deletion test/demo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ integrator = Leapfrog(initial_ϵ)
# - multinomial sampling scheme,
# - generalised No-U-Turn criterion, and
# - windowed adaptation for step-size and diagonal mass matrix
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
proposal = Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

# Run the sampler to draw samples from the specified Gaussian, where
Expand Down
Loading