Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No glue code #319

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
376 changes: 376 additions & 0 deletions Lab.ipynb

Large diffs are not rendered by default.

22 changes: 12 additions & 10 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ version = "0.4.6"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -19,6 +21,16 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

[extensions]
AdvancedHMCCUDAExt = "CUDA"
AdvancedHMCMCMCChainsExt = "MCMCChains"
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"

[compat]
AbstractMCMC = "4.2"
ArgCheck = "1, 2"
Expand All @@ -37,17 +49,7 @@ StatsBase = "0.31, 0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
julia = "1.6"

[extensions]
AdvancedHMCCUDAExt = "CUDA"
AdvancedHMCMCMCChainsExt = "MCMCChains"
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
30 changes: 5 additions & 25 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ using LogDensityProblemsAD: LogDensityProblemsAD

import AbstractMCMC
using AbstractMCMC: LogDensityModel
using DynamicPPL

import StatsBase: sample

Expand Down Expand Up @@ -64,30 +65,6 @@ export Trajectory,
MultinomialTS,
find_good_stepsize

# Useful defaults

struct NUTS{TS,TC} end

"""
$(SIGNATURES)

Convenient constructor for the no-U-turn sampler (NUTS).
This falls back to `HMCKernel(Trajectory{TS}(int, TC(args...; kwargs...)))` where

- `TS<:Union{MultinomialTS, SliceTS}` is the type for trajectory sampler
- `TC<:Union{ClassicNoUTurn, GeneralisedNoUTurn, StrictGeneralisedNoUTurn}` is the type for termination criterion.

See [`ClassicNoUTurn`](@ref), [`GeneralisedNoUTurn`](@ref) and [`StrictGeneralisedNoUTurn`](@ref) for details in parameters.
"""
NUTS{TS,TC}(int::AbstractIntegrator, args...; kwargs...) where {TS,TC} =
HMCKernel(Trajectory{TS}(int, TC(args...; kwargs...)))
NUTS(int::AbstractIntegrator, args...; kwargs...) =
HMCKernel(Trajectory{MultinomialTS}(int, GeneralisedNoUTurn(args...; kwargs...)))
NUTS(ϵ::AbstractScalarOrVec{<:Real}) =
HMCKernel(Trajectory{MultinomialTS}(Leapfrog(ϵ), GeneralisedNoUTurn()))

export NUTS

# Deprecations for trajectory.jl

abstract type AbstractTrajectory end
Expand All @@ -103,6 +80,7 @@ struct StaticTrajectory{TS} end
Trajectory{EndPointTS}(Leapfrog(ϵ), FixedNSteps(L)),
)

#=
struct HMCDA{TS} end
@deprecate HMCDA{TS}(int::AbstractIntegrator, λ) where {TS} HMCKernel(
Trajectory{TS}(int, FixedIntegrationTime(λ)),
Expand All @@ -113,10 +91,11 @@ struct HMCDA{TS} end
@deprecate HMCDA(ϵ::AbstractScalarOrVec{<:Real}, λ) HMCKernel(
Trajectory{EndPointTS}(Leapfrog(ϵ), FixedIntegrationTime(λ)),
)
=#

@deprecate find_good_eps find_good_stepsize

export StaticTrajectory, HMCDA, find_good_eps
export StaticTrajectory, find_good_eps #HMCDA,

include("adaptation/Adaptation.jl")
using .Adaptation
Expand Down Expand Up @@ -168,6 +147,7 @@ include("diagnosis.jl")
include("sampler.jl")
export sample

include("constructors.jl")
include("abstractmcmc.jl")

## Without explicit AD backend
Expand Down
191 changes: 44 additions & 147 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,3 @@
"""
HMCSampler

A `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl.

# Fields

$(FIELDS)

# Notes

Note that all the fields have the prefix `initial_` to indicate
that these will not necessarily correspond to the `kernel`, `metric`,
and `adaptor` after sampling.

To access the updated fields use the resulting [`HMCState`](@ref).
"""
struct HMCSampler{K,M,A} <: AbstractMCMC.AbstractSampler
"Initial [`AbstractMCMCKernel`](@ref)."
initial_kernel::K
"Initial [`AbstractMetric`](@ref)."
initial_metric::M
"Initial [`AbstractAdaptor`](@ref)."
initial_adaptor::A
end
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation())

"""
HMCState

Expand Down Expand Up @@ -53,148 +26,72 @@ struct HMCState{
adaptor::TAdapt
end

"""
$(TYPEDSIGNATURES)

A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref).
"""
function AbstractMCMC.sample(
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
N::Integer;
kwargs...,
)
return AbstractMCMC.sample(
Random.GLOBAL_RNG,
model,
kernel,
metric,
adaptor,
N;
kwargs...,
)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
N::Integer;
progress = true,
verbose = false,
callback = nothing,
function AbstractMCMC.step(
rng::AbstractRNG,
model,#::DynamicPPL.model,
spl::AbstractMCMC.AbstractSampler;
init_params = nothing,
kwargs...,
)
sampler = HMCSampler(kernel, metric, adaptor)
if callback === nothing
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
progress = false # don't use AMCMC's progress-funtionality
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
)
)

vi = kwargs[:vi]
d = kwargs[:d]
n_adapts = spl.n_adapts

# We will need to implement this but it is going to be
# Interesting how to plug the transforms along the sampling
# processes
# vi_t = Turing.link!!(vi, model)

# Define metric
if spl.metric == nothing
metric = DiagEuclideanMetric(d)
else
metric = spl.metric
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric = spl.metric
metric = spl.metric

end

return AbstractMCMC.mcmcsample(
rng,
model,
sampler,
N;
progress = progress,
verbose = verbose,
callback = callback,
kwargs...,
)
end

function AbstractMCMC.sample(
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
parallel::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
kwargs...,
)
return AbstractMCMC.sample(
Random.GLOBAL_RNG,
model,
kernel,
metric,
adaptor,
N,
nchains;
kwargs...,
)
end
# Construct the hamiltonian using the initial metric
hamiltonian = Hamiltonian(metric, model)

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
parallel::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
progress = true,
verbose = false,
callback = nothing,
kwargs...,
)
sampler = HMCSampler(kernel, metric, adaptor)
if callback === nothing
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
progress = false # don't use AMCMC's progress-funtionality
# Find good eps if not provided one
# Before it was spl.alg.ϵ to allow prior sampling
if iszero(spl.ϵ)
# Extract parameters.
theta = vi[spl]
ϵ = find_good_stepsize(rng, hamiltonian, theta)
println(string("Found initial step size ", ϵ))
else
ϵ = spl.ϵ
end

return AbstractMCMC.mcmcsample(
rng,
model,
sampler,
parallel,
N,
nchains;
progress = progress,
verbose = verbose,
callback = callback,
kwargs...,
)
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::LogDensityModel,
spl::HMCSampler;
init_params = nothing,
kwargs...,
)
metric = spl.initial_metric
κ = spl.initial_kernel
adaptor = spl.initial_adaptor
integrator = spl.integrator(ϵ)
kernel = spl.kernel(integrator)
adaptor = spl.adaptor(metric, integrator)
spl = HMCSampler(kernel, metric, adaptor)

if init_params === nothing
init_params = randn(rng, size(metric, 1))
end

# Construct the hamiltonian using the initial metric
hamiltonian = Hamiltonian(metric, model)

# Get an initial sample.
h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params)

# Compute next transition and state.
state = HMCState(0, t, h.metric, κ, adaptor)
state = HMCState(0, t, h.metric, kernel, adaptor)

# Take actual first step.
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts = n_adapts,
kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts = n_adapts,
kwargs...)
return AbstractMCMC.step(rng, model, spl, state; n_adapts = n_adapts, kwargs...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
kwargs...)
kwargs...,
)

end

function AbstractMCMC.step(
rng::AbstractRNG,
model::LogDensityModel,
spl::HMCSampler,
spl::AbstractMCMC.AbstractSampler,
state::HMCState;
nadapts::Int = 0,
kwargs...,
Expand Down
Loading