-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No glue code #319
No glue code #319
Changes from 14 commits
43cfe65
3158bb6
37d6831
6fe6436
304d401
46f0803
2f6f2c1
dfd5e74
3038441
deb0555
00b837a
ce96cac
1f8c5a7
612e10b
3bbc668
1bffe99
b941529
8b1f962
0f45cc8
3e1c403
8fa9fcb
c582abf
39941fd
3684a1e
62c2096
1893e3f
1af622c
acaa289
80f0d8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,30 +1,3 @@ | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
HMCSampler | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
A `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Fields | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
$(FIELDS) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Notes | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Note that all the fields have the prefix `initial_` to indicate | ||||||||||||||||||||||||
that these will not necessarily correspond to the `kernel`, `metric`, | ||||||||||||||||||||||||
and `adaptor` after sampling. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
To access the updated fields use the resulting [`HMCState`](@ref). | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
struct HMCSampler{K,M,A} <: AbstractMCMC.AbstractSampler | ||||||||||||||||||||||||
"Initial [`AbstractMCMCKernel`](@ref)." | ||||||||||||||||||||||||
initial_kernel::K | ||||||||||||||||||||||||
"Initial [`AbstractMetric`](@ref)." | ||||||||||||||||||||||||
initial_metric::M | ||||||||||||||||||||||||
"Initial [`AbstractAdaptor`](@ref)." | ||||||||||||||||||||||||
initial_adaptor::A | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation()) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
HMCState | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
@@ -53,148 +26,72 @@ struct HMCState{ | |||||||||||||||||||||||
adaptor::TAdapt | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
$(TYPEDSIGNATURES) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref). | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
function AbstractMCMC.sample( | ||||||||||||||||||||||||
model::LogDensityModel, | ||||||||||||||||||||||||
kernel::AbstractMCMCKernel, | ||||||||||||||||||||||||
metric::AbstractMetric, | ||||||||||||||||||||||||
adaptor::AbstractAdaptor, | ||||||||||||||||||||||||
N::Integer; | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
return AbstractMCMC.sample( | ||||||||||||||||||||||||
Random.GLOBAL_RNG, | ||||||||||||||||||||||||
model, | ||||||||||||||||||||||||
kernel, | ||||||||||||||||||||||||
metric, | ||||||||||||||||||||||||
adaptor, | ||||||||||||||||||||||||
N; | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
function AbstractMCMC.sample( | ||||||||||||||||||||||||
rng::Random.AbstractRNG, | ||||||||||||||||||||||||
model::LogDensityModel, | ||||||||||||||||||||||||
kernel::AbstractMCMCKernel, | ||||||||||||||||||||||||
metric::AbstractMetric, | ||||||||||||||||||||||||
adaptor::AbstractAdaptor, | ||||||||||||||||||||||||
N::Integer; | ||||||||||||||||||||||||
progress = true, | ||||||||||||||||||||||||
verbose = false, | ||||||||||||||||||||||||
callback = nothing, | ||||||||||||||||||||||||
function AbstractMCMC.step( | ||||||||||||||||||||||||
rng::AbstractRNG, | ||||||||||||||||||||||||
model,#::DynamicPPL.model, | ||||||||||||||||||||||||
spl::AbstractMCMC.AbstractSampler; | ||||||||||||||||||||||||
init_params = nothing, | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
sampler = HMCSampler(kernel, metric, adaptor) | ||||||||||||||||||||||||
if callback === nothing | ||||||||||||||||||||||||
callback = HMCProgressCallback(N, progress = progress, verbose = verbose) | ||||||||||||||||||||||||
progress = false # don't use AMCMC's progress-funtionality | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
vi = kwargs[:vi] | ||||||||||||||||||||||||
d = kwargs[:d] | ||||||||||||||||||||||||
n_adapts = spl.n_adapts | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# We will need to implement this but it is going to be | ||||||||||||||||||||||||
# Interesting how to plug the transforms along the sampling | ||||||||||||||||||||||||
# processes | ||||||||||||||||||||||||
# vi_t = Turing.link!!(vi, model) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Define metric | ||||||||||||||||||||||||
if spl.metric == nothing | ||||||||||||||||||||||||
metric = DiagEuclideanMetric(d) | ||||||||||||||||||||||||
else | ||||||||||||||||||||||||
metric = spl.metric | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return AbstractMCMC.mcmcsample( | ||||||||||||||||||||||||
rng, | ||||||||||||||||||||||||
model, | ||||||||||||||||||||||||
sampler, | ||||||||||||||||||||||||
N; | ||||||||||||||||||||||||
progress = progress, | ||||||||||||||||||||||||
verbose = verbose, | ||||||||||||||||||||||||
callback = callback, | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
function AbstractMCMC.sample( | ||||||||||||||||||||||||
model::LogDensityModel, | ||||||||||||||||||||||||
kernel::AbstractMCMCKernel, | ||||||||||||||||||||||||
metric::AbstractMetric, | ||||||||||||||||||||||||
adaptor::AbstractAdaptor, | ||||||||||||||||||||||||
parallel::AbstractMCMC.AbstractMCMCEnsemble, | ||||||||||||||||||||||||
N::Integer, | ||||||||||||||||||||||||
nchains::Integer; | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
return AbstractMCMC.sample( | ||||||||||||||||||||||||
Random.GLOBAL_RNG, | ||||||||||||||||||||||||
model, | ||||||||||||||||||||||||
kernel, | ||||||||||||||||||||||||
metric, | ||||||||||||||||||||||||
adaptor, | ||||||||||||||||||||||||
N, | ||||||||||||||||||||||||
nchains; | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
# Construct the hamiltonian using the initial metric | ||||||||||||||||||||||||
hamiltonian = Hamiltonian(metric, model) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
function AbstractMCMC.sample( | ||||||||||||||||||||||||
rng::Random.AbstractRNG, | ||||||||||||||||||||||||
model::LogDensityModel, | ||||||||||||||||||||||||
kernel::AbstractMCMCKernel, | ||||||||||||||||||||||||
metric::AbstractMetric, | ||||||||||||||||||||||||
adaptor::AbstractAdaptor, | ||||||||||||||||||||||||
parallel::AbstractMCMC.AbstractMCMCEnsemble, | ||||||||||||||||||||||||
N::Integer, | ||||||||||||||||||||||||
nchains::Integer; | ||||||||||||||||||||||||
progress = true, | ||||||||||||||||||||||||
verbose = false, | ||||||||||||||||||||||||
callback = nothing, | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
sampler = HMCSampler(kernel, metric, adaptor) | ||||||||||||||||||||||||
if callback === nothing | ||||||||||||||||||||||||
callback = HMCProgressCallback(N, progress = progress, verbose = verbose) | ||||||||||||||||||||||||
progress = false # don't use AMCMC's progress-funtionality | ||||||||||||||||||||||||
# Find good eps if not provided one | ||||||||||||||||||||||||
# Before it was spl.alg.ϵ to allow prior sampling | ||||||||||||||||||||||||
if iszero(spl.ϵ) | ||||||||||||||||||||||||
# Extract parameters. | ||||||||||||||||||||||||
theta = vi[spl] | ||||||||||||||||||||||||
ϵ = find_good_stepsize(rng, hamiltonian, theta) | ||||||||||||||||||||||||
println(string("Found initial step size ", ϵ)) | ||||||||||||||||||||||||
else | ||||||||||||||||||||||||
ϵ = spl.ϵ | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return AbstractMCMC.mcmcsample( | ||||||||||||||||||||||||
rng, | ||||||||||||||||||||||||
model, | ||||||||||||||||||||||||
sampler, | ||||||||||||||||||||||||
parallel, | ||||||||||||||||||||||||
N, | ||||||||||||||||||||||||
nchains; | ||||||||||||||||||||||||
progress = progress, | ||||||||||||||||||||||||
verbose = verbose, | ||||||||||||||||||||||||
callback = callback, | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
function AbstractMCMC.step( | ||||||||||||||||||||||||
rng::AbstractRNG, | ||||||||||||||||||||||||
model::LogDensityModel, | ||||||||||||||||||||||||
spl::HMCSampler; | ||||||||||||||||||||||||
init_params = nothing, | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
metric = spl.initial_metric | ||||||||||||||||||||||||
κ = spl.initial_kernel | ||||||||||||||||||||||||
adaptor = spl.initial_adaptor | ||||||||||||||||||||||||
integrator = spl.integrator(ϵ) | ||||||||||||||||||||||||
kernel = spl.kernel(integrator) | ||||||||||||||||||||||||
adaptor = spl.adaptor(metric, integrator) | ||||||||||||||||||||||||
spl = HMCSampler(kernel, metric, adaptor) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if init_params === nothing | ||||||||||||||||||||||||
init_params = randn(rng, size(metric, 1)) | ||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Construct the hamiltonian using the initial metric | ||||||||||||||||||||||||
hamiltonian = Hamiltonian(metric, model) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Get an initial sample. | ||||||||||||||||||||||||
h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Compute next transition and state. | ||||||||||||||||||||||||
state = HMCState(0, t, h.metric, κ, adaptor) | ||||||||||||||||||||||||
state = HMCState(0, t, h.metric, kernel, adaptor) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Take actual first step. | ||||||||||||||||||||||||
return AbstractMCMC.step(rng, model, spl, state; kwargs...) | ||||||||||||||||||||||||
return AbstractMCMC.step( | ||||||||||||||||||||||||
rng, | ||||||||||||||||||||||||
model, | ||||||||||||||||||||||||
spl, | ||||||||||||||||||||||||
state; | ||||||||||||||||||||||||
n_adapts = n_adapts, | ||||||||||||||||||||||||
kwargs...) | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||
end | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
function AbstractMCMC.step( | ||||||||||||||||||||||||
rng::AbstractRNG, | ||||||||||||||||||||||||
model::LogDensityModel, | ||||||||||||||||||||||||
spl::HMCSampler, | ||||||||||||||||||||||||
spl::AbstractMCMC.AbstractSampler, | ||||||||||||||||||||||||
state::HMCState; | ||||||||||||||||||||||||
nadapts::Int = 0, | ||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶