-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No glue code #319
No glue code #319
Conversation
src/sampler.jl
Outdated
function sample(model::DynamicPPL.Model, ϵ::Number, TAP::Number, n_samples::Int, n_adapts::Int; | ||
initial_θ=initial_θ, progress=true, kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function sample(model::DynamicPPL.Model, ϵ::Number, TAP::Number, n_samples::Int, n_adapts::Int; | |
initial_θ=initial_θ, progress=true, kwargs...) | |
function sample( | |
model::DynamicPPL.Model, | |
ϵ::Number, | |
TAP::Number, | |
n_samples::Int, | |
n_adapts::Int; | |
initial_θ = initial_θ, | |
progress = true, | |
kwargs..., | |
) |
src/sampler.jl
Outdated
vsyms = _name_variables(vi, dist_lengths) | ||
d = length(vsyms) | ||
|
||
metric = DiagEuclideanMetric(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric = DiagEuclideanMetric(d) | |
metric = DiagEuclideanMetric(d) |
src/sampler.jl
Outdated
|
||
metric = DiagEuclideanMetric(d) | ||
integrator = Leapfrog(ϵ) | ||
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) | |
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator) |
src/sampler.jl
Outdated
return sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts; | ||
progress=progress, kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts; | |
progress=progress, kwargs...) | |
return sample( | |
model, | |
metric, | |
proposal, | |
initial_θ, | |
n_samples, | |
adaptor, | |
n_adapts; | |
progress = progress, | |
kwargs..., | |
) |
src/turing_utils.jl
Outdated
vsyms = keys(vi) | ||
names = [] | ||
for (vsym, dist_length) in zip(vsyms, dist_lengths) | ||
if dist_length==1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
if dist_length==1 | |
if dist_length == 1 |
src/turing_utils.jl
Outdated
name = [vsym] | ||
append!(names, name) | ||
else | ||
name = [DynamicPPL.VarName(Symbol(vsym, i,)) for i in 1:dist_length] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
name = [DynamicPPL.VarName(Symbol(vsym, i,)) for i in 1:dist_length] | |
name = [DynamicPPL.VarName(Symbol(vsym, i)) for i = 1:dist_length] |
src/turing_utils.jl
Outdated
else | ||
name = [DynamicPPL.VarName(Symbol(vsym, i,)) for i in 1:dist_length] | ||
append!(names, name) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/turing_utils.jl
Outdated
end | ||
end | ||
return names | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
I have now coded the changes into AdvancedHMC src. n_samples, n_adapts = 10_000, 1_000
sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts) or even simpler samples, stats = sample(model, 0.1, 0.95, n_samples, n_adapts; initial_θ=initial_θ) which will use NUTS by defautl. The unbounded to bounded space transforms are still missing. |
src/abstractmcmc.jl
Outdated
d = length(vsyms) | ||
|
||
# wrap metric, kernel and adaptor into HMCSampler | ||
metric = DiagEuclideanMetric(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric = DiagEuclideanMetric(d) | |
metric = DiagEuclideanMetric(d) |
src/abstractmcmc.jl
Outdated
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) | ||
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(settings.TAP, integrator)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) | |
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(settings.TAP, integrator)) | |
kernel = AdvancedHMC.NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator) | |
adaptor = | |
StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(settings.TAP, integrator)) |
src/abstractmcmc.jl
Outdated
@@ -25,6 +25,15 @@ struct HMCSampler{K,M,A} <: AbstractMCMC.AbstractSampler | |||
end | |||
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation()) | |||
|
|||
# Convinience constructor | |||
function NUTSSampler(ϵ::Float64, TAP::Float64, d::Int) | |||
metric = DiagEuclideanMetric(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric = DiagEuclideanMetric(d) | |
metric = DiagEuclideanMetric(d) |
src/abstractmcmc.jl
Outdated
function NUTSSampler(ϵ::Float64, TAP::Float64, d::Int) | ||
metric = DiagEuclideanMetric(d) | ||
integrator = Leapfrog(ϵ) | ||
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) | |
kernel = AdvancedHMC.NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator) |
src/abstractmcmc.jl
Outdated
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) | ||
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator)) | ||
return HMCSampler(kernel, metric, adaptor) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/abstractmcmc.jl
Outdated
# No glue code # | ||
################ | ||
function AbstractMCMC.sample( | ||
model::DynamicPPL.Model, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
model::DynamicPPL.Model, | |
model::DynamicPPL.Model, |
src/abstractmcmc.jl
Outdated
progress = true, | ||
verbose = false, | ||
callback = nothing, | ||
kwargs..., | ||
) | ||
sampler = HMCSampler(kernel, metric, adaptor) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
) | |
) |
src/abstractmcmc.jl
Outdated
ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi, model, ctxt)) | ||
d = LogDensityProblems.dimension(ℓ) | ||
model = AbstractMCMC.LogDensityModel(ℓ) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
|
||
if init_params === nothing | ||
init_params = randn(rng, size(metric, 1)) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
) | |
) |
src/abstractmcmc.jl
Outdated
if spl.metric == nothing | ||
metric = DiagEuclideanMetric(d) | ||
else | ||
metric = spl.metric |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric = spl.metric | |
metric = spl.metric |
src/abstractmcmc.jl
Outdated
return AbstractMCMC.step( | ||
rng, | ||
model, | ||
spl, | ||
state; | ||
n_adapts = n_adapts, | ||
kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return AbstractMCMC.step( | |
rng, | |
model, | |
spl, | |
state; | |
n_adapts = n_adapts, | |
kwargs...) | |
return AbstractMCMC.step(rng, model, spl, state; n_adapts = n_adapts, kwargs...) |
src/constructors.jl
Outdated
metric | ||
integrator | ||
kernel | ||
adaptor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric | |
integrator | |
kernel | |
adaptor | |
metric::Any | |
integrator::Any | |
kernel::Any | |
adaptor::Any |
src/constructors.jl
Outdated
max_depth::Int=10, | ||
Δ_max::Float64=1000.0, | ||
init_ϵ::Float64=0.0, | ||
metric=nothing, | ||
integrator=Leapfrog, | ||
kernel = NUTS_kernel{MultinomialTS, GeneralisedNoUTurn} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
max_depth::Int=10, | |
Δ_max::Float64=1000.0, | |
init_ϵ::Float64=0.0, | |
metric=nothing, | |
integrator=Leapfrog, | |
kernel = NUTS_kernel{MultinomialTS, GeneralisedNoUTurn} | |
) | |
max_depth::Int = 10, | |
Δ_max::Float64 = 1000.0, | |
init_ϵ::Float64 = 0.0, | |
metric = nothing, | |
integrator = Leapfrog, | |
kernel = NUTS_kernel{MultinomialTS,GeneralisedNoUTurn}, | |
) |
src/constructors.jl
Outdated
return StanHMCAdaptor(MassMatrixAdaptor(metric), | ||
StepSizeAdaptor(TAP, integrator)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return StanHMCAdaptor(MassMatrixAdaptor(metric), | |
StepSizeAdaptor(TAP, integrator)) | |
end | |
return StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator)) | |
end |
StepSizeAdaptor(TAP, integrator)) | ||
end | ||
NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, kernel, adaptor) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/constructors.jl
Outdated
|
||
function NUTS_kernel(integrator) | ||
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/constructors.jl
Outdated
max_depth::Int=10, | ||
Δ_max::Float64=1000.0, | ||
init_ϵ::Float64=0.0, | ||
metric=nothing, | ||
integrator=Leapfrog) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
max_depth::Int=10, | |
Δ_max::Float64=1000.0, | |
init_ϵ::Float64=0.0, | |
metric=nothing, | |
integrator=Leapfrog) | |
max_depth::Int = 10, | |
Δ_max::Float64 = 1000.0, | |
init_ϵ::Float64 = 0.0, | |
metric = nothing, | |
integrator = Leapfrog, | |
) |
src/constructors.jl
Outdated
NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor) | ||
end | ||
|
||
export NUTS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
export NUTS | |
export NUTS |
src/constructors.jl
Outdated
function kernel(integrator) | ||
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(n_leapfrog))) | ||
end | ||
return kernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return kernel | |
return kernel |
src/constructors.jl
Outdated
metric | ||
integrator | ||
kernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric | |
integrator | |
kernel | |
metric::Any | |
integrator::Any | |
kernel::Any |
src/constructors.jl
Outdated
function kernel(integrator) | ||
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedIntegrationTime(λ))) | ||
end | ||
return kernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return kernel | |
return kernel |
src/constructors.jl
Outdated
n_adapts :: Int # number of samples with adaption for ϵ | ||
TAP :: Float64 # target accept rate | ||
λ :: Float64 # target leapfrog length | ||
ϵ :: Float64 # (initial) step size | ||
metric | ||
integrator | ||
kernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
n_adapts :: Int # number of samples with adaption for ϵ | |
TAP :: Float64 # target accept rate | |
λ :: Float64 # target leapfrog length | |
ϵ :: Float64 # (initial) step size | |
metric | |
integrator | |
kernel | |
n_adapts::Int # number of samples with adaption for ϵ | |
TAP::Float64 # target accept rate | |
λ::Float64 # target leapfrog length | |
ϵ::Float64 # (initial) step size | |
metric::Any | |
integrator::Any | |
kernel::Any |
src/constructors.jl
Outdated
ϵ::Float64=0.0, | ||
metric=nothing, | ||
integrator=Leapfrog) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
ϵ::Float64=0.0, | |
metric=nothing, | |
integrator=Leapfrog) | |
ϵ::Float64 = 0.0, | |
metric = nothing, | |
integrator = Leapfrog, | |
) |
src/constructors.jl
Outdated
return StanHMCAdaptor(MassMatrixAdaptor(metric), | ||
StepSizeAdaptor(TAP, integrator)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return StanHMCAdaptor(MassMatrixAdaptor(metric), | |
StepSizeAdaptor(TAP, integrator)) | |
end | |
return StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator)) | |
end |
src/constructors.jl
Outdated
return HMCDA(n_adapts, TAP, λ, ϵ, metric, integrator, kernel, adaptor) | ||
end | ||
|
||
export HMCDA |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
export HMCDA | |
export HMCDA |
src/abstractmcmc.jl
Outdated
progress = false # don't use AMCMC's progress-funtionality | ||
) | ||
vi = kwargs[:vi] | ||
d = kwargs[:d] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
d = kwargs[:d] | |
d = kwargs[:d] |
src/constructors.jl
Outdated
metric | ||
integrator | ||
kernel | ||
adaptor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric | |
integrator | |
kernel | |
adaptor | |
metric::Any | |
integrator::Any | |
kernel::Any | |
adaptor::Any |
src/constructors.jl
Outdated
n_adapts :: Int # number of samples with adaption for ϵ | ||
TAP :: Float64 # target accept rate | ||
λ :: Float64 # target leapfrog length | ||
ϵ :: Float64 # (initial) step size | ||
metric | ||
integrator | ||
kernel | ||
adaptor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
n_adapts :: Int # number of samples with adaption for ϵ | |
TAP :: Float64 # target accept rate | |
λ :: Float64 # target leapfrog length | |
ϵ :: Float64 # (initial) step size | |
metric | |
integrator | |
kernel | |
adaptor | |
n_adapts::Int # number of samples with adaption for ϵ | |
TAP::Float64 # target accept rate | |
λ::Float64 # target leapfrog length | |
ϵ::Float64 # (initial) step size | |
metric::Any | |
integrator::Any | |
kernel::Any | |
adaptor::Any |
@@ -246,4 +246,4 @@ function sample( | |||
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h κ EBFMI_est average_acceptance_rate | |||
end | |||
return θs, stats | |||
end | |||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/abstractmcmc.jl
Outdated
spl, | ||
state; | ||
n_adapts = n_adapts, | ||
kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
kwargs...) | |
kwargs..., | |
) |
src/abstractmcmc.jl
Outdated
state::HMCState; | ||
nadapts::Int = 0, | ||
kwargs..., | ||
) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
) | |
) |
src/abstractmcmc.jl
Outdated
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f)) | ||
function getdimensions(f::AbstractMCMC.LogDensityModel) | ||
return LogDensityProblems.dimension(f.logdensity) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/abstractmcmc.jl
Outdated
println(typeof(hamiltonian)<:Hamiltonian) | ||
return AbstractMCMC.step( | ||
rng, | ||
model, | ||
spl, | ||
state; | ||
n_adapts = n_adapts, | ||
kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
println(typeof(hamiltonian)<:Hamiltonian) | |
return AbstractMCMC.step( | |
rng, | |
model, | |
spl, | |
state; | |
n_adapts = n_adapts, | |
kwargs...) | |
println(typeof(hamiltonian) <: Hamiltonian) | |
return AbstractMCMC.step(rng, model, spl, state; n_adapts = n_adapts, kwargs...) |
src/constructors.jl
Outdated
return StanHMCAdaptor(MassMatrixAdaptor(metric), | ||
StepSizeAdaptor(TAP, integrator)) | ||
end | ||
AHMC_NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return StanHMCAdaptor(MassMatrixAdaptor(metric), | |
StepSizeAdaptor(TAP, integrator)) | |
end | |
AHMC_NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor) | |
end | |
return StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator)) | |
end | |
AHMC_NUTS( | |
n_adapts, | |
TAP, | |
max_depth, | |
Δ_max, | |
init_ϵ, | |
metric, | |
integrator, | |
NUTS_kernel, | |
adaptor, | |
) | |
end |
src/constructors.jl
Outdated
AHMC_NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor) | ||
end | ||
|
||
export AHMC_NUTS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
export AHMC_NUTS | |
export AHMC_NUTS |
src/constructors.jl
Outdated
return AHMC_HMCDA(n_adapts, TAP, λ, ϵ, metric, integrator, kernel, adaptor) | ||
end | ||
|
||
export AHMC_HMCDA |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
export AHMC_HMCDA | |
export AHMC_HMCDA |
src/abstractmcmc.jl
Outdated
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), | ||
StepSizeAdaptor(spl.alg.TAP, integrator)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), | |
StepSizeAdaptor(spl.alg.TAP, integrator)) | |
adaptor = StanHMCAdaptor( | |
MassMatrixAdaptor(metric), | |
StepSizeAdaptor(spl.alg.TAP, integrator), | |
) |
src/abstractmcmc.jl
Outdated
else | ||
adaptor = spl.adaptor | ||
n_adapts = kwargs[:n_adapts] | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/constructors.jl
Outdated
# Basic use | ||
HMCSampler(algorithm) = HMCSampler(algorithm, nothing, nothing, nothing, nothing) | ||
# Expert use | ||
HMCSampler(integrator, kernel, metric, adaptor) = HMCSampler(Custom_alg, integrator, kernel, metric, adaptor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
HMCSampler(integrator, kernel, metric, adaptor) = HMCSampler(Custom_alg, integrator, kernel, metric, adaptor) | |
HMCSampler(integrator, kernel, metric, adaptor) = | |
HMCSampler(Custom_alg, integrator, kernel, metric, adaptor) |
src/constructors.jl
Outdated
########## | ||
# Custom # | ||
########## | ||
struct Custom_alg<:SamplingAlgorithm end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
struct Custom_alg<:SamplingAlgorithm end | |
struct Custom_alg <: SamplingAlgorithm end |
src/constructors.jl
Outdated
max_depth::Int=10, | ||
Δ_max::Float64=1000.0, | ||
ϵ::Float64=0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
max_depth::Int=10, | |
Δ_max::Float64=1000.0, | |
ϵ::Float64=0.0) | |
max_depth::Int = 10, | |
Δ_max::Float64 = 1000.0, | |
ϵ::Float64 = 0.0, | |
) |
src/constructors.jl
Outdated
function HMCDA( | ||
n_adapts::Int, | ||
TAP::Float64, | ||
λ::Float64; | ||
ϵ::Float64=0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function HMCDA( | |
n_adapts::Int, | |
TAP::Float64, | |
λ::Float64; | |
ϵ::Float64=0.0) | |
function HMCDA(n_adapts::Int, TAP::Float64, λ::Float64; ϵ::Float64 = 0.0) |
src/constructors.jl
Outdated
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator), | ||
StepSizeAdaptor(alg.TAP, integrator)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator), | |
StepSizeAdaptor(alg.TAP, integrator)) | |
end | |
return StanHMCAdaptor( | |
MassMatrixAdaptor(metric, integrator), | |
StepSizeAdaptor(alg.TAP, integrator), | |
) | |
end |
|
||
function make_kernel(alg::NUTS_alg, integrator) | ||
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
|
||
function make_kernel(alg::HMC_alg, integrator) | ||
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(alg.n_leapfrog))) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
|
||
function make_kernel(alg::HMCDA_alg, integrator) | ||
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedIntegrationTime(alg.λ))) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/abstractmcmc.jl
Outdated
return AbstractMCMC.step( | ||
rng, | ||
model, | ||
spl, | ||
state; | ||
n_adapts=n_adapts, | ||
kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return AbstractMCMC.step( | |
rng, | |
model, | |
spl, | |
state; | |
n_adapts=n_adapts, | |
kwargs...) | |
return AbstractMCMC.step(rng, model, spl, state; n_adapts = n_adapts, kwargs...) |
src/abstractmcmc.jl
Outdated
state::HMCState; | ||
nadapts::Int = 0, | ||
nadapts::Int=0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
nadapts::Int=0, | |
nadapts::Int = 0, |
src/abstractmcmc.jl
Outdated
@@ -217,14 +143,19 @@ function AbstractMCMC.step( | |||
h, κ, isadapted = adapt!(h, κ, adaptor, i, nadapts, t.z.θ, tstat.acceptance_rate) | |||
tstat = merge(tstat, (is_adapt = isadapted,)) | |||
|
|||
# Convert variables back | |||
vii_t = DynamicPPL.unflatten(vi_t, t.z.θ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
vii_t = DynamicPPL.unflatten(vi_t, t.z.θ) | |
vii_t = DynamicPPL.unflatten(vi_t, t.z.θ) |
src/abstractmcmc.jl
Outdated
vi = DynamicPPL.VarInfo(model, ctxt) | ||
vi_t = DynamicPPL.link!!(vi, model) | ||
logdensityfunction = DynamicPPL.LogDensityFunction(vi_t, model, ctxt) | ||
logdensityproblem = LogDensityProblemsAD.ADgradient(logdensityfunction) | ||
logdensitymodel = AbstractMCMC.LogDensityModel(logdensityproblem) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally, we should be independent of DynamicPPL APIs or internals from MCMC libraries. Maybe this should be transferred to DynamicPPL/AbstractMCMC.
src/abstractmcmc.jl
Outdated
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), | ||
StepSizeAdaptor(spl.alg.δ, integrator)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), | |
StepSizeAdaptor(spl.alg.δ, integrator)) | |
adaptor = StanHMCAdaptor( | |
MassMatrixAdaptor(metric), | |
StepSizeAdaptor(spl.alg.δ, integrator), | |
) |
state = HMCState(0, t, h.metric, κ, adaptor) | ||
|
||
state = HMCState(0, t, metric, κ, adaptor) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@@ -302,4 +312,4 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw | |||
elseif verbose && isadapted && i == nadapts | |||
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric | |||
end | |||
end | |||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
src/constructors.jl
Outdated
n_adapts :: Int # number of samples with adaption for ϵ | ||
δ :: Float64 # target accept rate | ||
λ :: Float64 # target leapfrog length | ||
ϵ :: Float64 # (initial) step size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
n_adapts :: Int # number of samples with adaption for ϵ | |
δ :: Float64 # target accept rate | |
λ :: Float64 # target leapfrog length | |
ϵ :: Float64 # (initial) step size | |
n_adapts::Int # number of samples with adaption for ϵ | |
δ::Float64 # target accept rate | |
λ::Float64 # target leapfrog length | |
ϵ::Float64 # (initial) step size |
src/constructors.jl
Outdated
function HMCDA( | ||
n_adapts::Int, | ||
δ::Float64, | ||
λ::Float64; | ||
ϵ::Float64=0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function HMCDA( | |
n_adapts::Int, | |
δ::Float64, | |
λ::Float64; | |
ϵ::Float64=0.0) | |
function HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64 = 0.0) |
src/constructors.jl
Outdated
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator), | ||
StepSizeAdaptor(alg.δ, integrator)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator), | |
StepSizeAdaptor(alg.δ, integrator)) | |
end | |
return StanHMCAdaptor( | |
MassMatrixAdaptor(metric, integrator), | |
StepSizeAdaptor(alg.δ, integrator), | |
) | |
end |
src/constructors.jl
Outdated
Base.@kwdef struct HMCSampler{I,K,M,A} <: AbstractMCMC.AbstractSampler | ||
alg::SamplingAlgorithm | ||
"[`integrator`](@ref)." | ||
integrator::I=nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
integrator::I=nothing | |
integrator::I = nothing |
"[`integrator`](@ref)." | ||
integrator::I=nothing | ||
"[`AbstractMCMCKernel`](@ref)." | ||
kernel::K=nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
kernel::K=nothing | |
kernel::K = nothing |
"[`AbstractMCMCKernel`](@ref)." | ||
kernel::K=nothing | ||
"[`AbstractMetric`](@ref)." | ||
metric::M=nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
metric::M=nothing | |
metric::M = nothing |
"[`AbstractMetric`](@ref)." | ||
metric::M=nothing | ||
"[`AbstractAdaptor`](@ref)." | ||
adaptor::A=nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
adaptor::A=nothing | |
adaptor::A = nothing |
src/constructors.jl
Outdated
function HMCDA( | ||
n_adapts::Int, | ||
δ::Float64, | ||
λ::Float64; | ||
ϵ::Float64=0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function HMCDA( | |
n_adapts::Int, | |
δ::Float64, | |
λ::Float64; | |
ϵ::Float64=0.0) | |
function HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64 = 0.0) |
src/constructors.jl
Outdated
To access the updated fields use the resulting [`HMCState`](@ref). | ||
""" | ||
Base.@kwdef struct HMCSampler{I,K,M,A} <: AbstractMCMC.AbstractSampler | ||
alg::HMCAlgorithm=Custom_alg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
alg::HMCAlgorithm=Custom_alg | |
alg::HMCAlgorithm = Custom_alg |
src/constructors.jl
Outdated
########## | ||
# Custom # | ||
########## | ||
struct Custom_alg<:HMCAlgorithm end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
struct Custom_alg<:HMCAlgorithm end | |
struct Custom_alg <: HMCAlgorithm end |
src/constructors.jl
Outdated
max_depth::Int=10, | ||
Δ_max::Float64=1000.0, | ||
ϵ::Float64=0.0) | ||
return HMCSampler(;alg=NUTS_alg(n_adapts, δ, max_depth, Δ_max, ϵ)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
max_depth::Int=10, | |
Δ_max::Float64=1000.0, | |
ϵ::Float64=0.0) | |
return HMCSampler(;alg=NUTS_alg(n_adapts, δ, max_depth, Δ_max, ϵ)) | |
end | |
max_depth::Int = 10, | |
Δ_max::Float64 = 1000.0, | |
ϵ::Float64 = 0.0, | |
) | |
return HMCSampler(; alg = NUTS_alg(n_adapts, δ, max_depth, Δ_max, ϵ)) | |
end |
src/constructors.jl
Outdated
ϵ::Float64, | ||
n_leapfrog::Int) | ||
|
||
return HMCSampler(;alg=HMC_alg(ϵ, n_leapfrog)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return HMCSampler(;alg=HMC_alg(ϵ, n_leapfrog)) | |
return HMCSampler(; alg = HMC_alg(ϵ, n_leapfrog)) |
src/constructors.jl
Outdated
function HMCDA( | ||
n_adapts::Int, | ||
δ::Float64, | ||
λ::Float64; | ||
ϵ::Float64=0.0) | ||
return HMCSampler(;alg=HMCDA_alg(n_adapts, δ, λ, ϵ)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function HMCDA( | |
n_adapts::Int, | |
δ::Float64, | |
λ::Float64; | |
ϵ::Float64=0.0) | |
return HMCSampler(;alg=HMCDA_alg(n_adapts, δ, λ, ϵ)) | |
function HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64 = 0.0) | |
return HMCSampler(; alg = HMCDA_alg(n_adapts, δ, λ, ϵ)) |
src/abstractmcmc.jl
Outdated
d = d=LogDensityProblems.dimension(logdensity) | ||
metric = make_metric(spl; d=d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
d = d=LogDensityProblems.dimension(logdensity) | |
metric = make_metric(spl; d=d) | |
d = d = LogDensityProblems.dimension(logdensity) | |
metric = make_metric(spl; d = d) |
src/abstractmcmc.jl
Outdated
integrator = make_integrator(spl; | ||
rng=rng, | ||
hamiltonian=hamiltonian, | ||
init_params=init_params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
integrator = make_integrator(spl; | |
rng=rng, | |
hamiltonian=hamiltonian, | |
init_params=init_params) | |
integrator = make_integrator( | |
spl; | |
rng = rng, | |
hamiltonian = hamiltonian, | |
init_params = init_params, | |
) |
@@ -0,0 +1,204 @@ | |||
abstract type AbstractHMCSampler <:AbstractMCMC.AbstractSampler end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
abstract type AbstractHMCSampler <:AbstractMCMC.AbstractSampler end | |
abstract type AbstractHMCSampler <: AbstractMCMC.AbstractSampler end |
""" | ||
Base.@kwdef struct CustomHMC{I,K,M,A} <: AbstractMCMC.AbstractSampler | ||
"[`integrator`](@ref)." | ||
integrator::I=Leapfrog |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
integrator::I=Leapfrog | |
integrator::I = Leapfrog |
max_depth::Int=10 # maximum tree depth | ||
Δ_max::Float64=1000.0 # maximum error | ||
init_ϵ::Float64=0.0 # (initial) step size | ||
integrator_method=Leapfrog # integrator method | ||
metric_type=DiagEuclideanMetric # metric type |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
max_depth::Int=10 # maximum tree depth | |
Δ_max::Float64=1000.0 # maximum error | |
init_ϵ::Float64=0.0 # (initial) step size | |
integrator_method=Leapfrog # integrator method | |
metric_type=DiagEuclideanMetric # metric type | |
max_depth::Int = 10 # maximum tree depth | |
Δ_max::Float64 = 1000.0 # maximum error | |
init_ϵ::Float64 = 0.0 # (initial) step size | |
integrator_method = Leapfrog # integrator method | |
metric_type = DiagEuclideanMetric # metric type |
function make_adaptor(spl::Union{NUTS_alg, HMCDA_alg}, metric, integrator) | ||
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), | ||
StepSizeAdaptor(spl.δ, integrator)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function make_adaptor(spl::Union{NUTS_alg, HMCDA_alg}, metric, integrator) | |
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), | |
StepSizeAdaptor(spl.δ, integrator)) | |
function make_adaptor(spl::Union{NUTS_alg,HMCDA_alg}, metric, integrator) | |
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(spl.δ, integrator)) |
StepSizeAdaptor(spl.δ, integrator)) | ||
n_adapts = spl.n_adapts | ||
return n_adapts, adaptor | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
|
||
function make_adaptor(spl::HMC_alg, metric, integrator) | ||
return 0, NoAdaptation() | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
return 0, NoAdaptation() | ||
end | ||
|
||
function make_adaptor(spl::CustomHMC, metric, integrator) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function make_adaptor(spl::CustomHMC, metric, integrator) | |
function make_adaptor(spl::CustomHMC, metric, integrator) |
|
||
function make_adaptor(spl::CustomHMC, metric, integrator) | ||
return spl.n_adapts, spl.adaptor | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
function make_integrator(rng, spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, | ||
hamiltonian, init_params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function make_integrator(rng, spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, | |
hamiltonian, init_params) | |
function make_integrator( | |
rng, | |
spl::Union{HMC_alg,NUTS_alg,HMCDA_alg}, | |
hamiltonian, | |
init_params, | |
) |
|
||
######### | ||
|
||
function make_metric(spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, logdensity) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function make_metric(spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, logdensity) | |
function make_metric(spl::Union{HMC_alg,NUTS_alg,HMCDA_alg}, logdensity) |
Hi All!
My name is Jaime Ruiz Zapatero and I am 3rd year PhD in astrophysics at Oxford.
@yebai and I have been discussing possible ways of interfacing Turing and AdvancedHMC in a more general without the current glue code currently present inside Turing.
This motivated by the limitations of the current interface already discussed in this PR by @sethaxen.
The fundamental idea is to extract a
LogDensityProblem
object from a generic Turing model which then is used to build aHamiltonian
object for AdvancedHMC.This can be done as follows:
Where model is a conditioned Turing model.
I have written a Neal's funnel example of how this interface can work which you can find in this notebook. Here I have also drafted some place holder functions to wrap the interface such that it is user-friendly. Essentially, I have added an additional signature to
sample
and created a wrapper structure for the sampler ingridients. To do so I have tried to follow the idea proposed in this issue by @yebai.This example is already fully functional. However, it is missing an important aspect. Turing models can have priors with hard boundaries. However, samplers work best in continuous spaces. The solution is map the bounded prior space to the unbounded sampling space. Turing does this internally but to avoid relying too much on Turing one can also do the following:
Then the generated samples can be transformed back to the prior space using:
Note that these already account for the jacobian of the transformation. Also, this might not be the most elegant way of doing the transformation.
Question: Where would you suggest writing these functions into the code? My guess would be within sampler but I don't want to mess around with the design idea.
Question: Where would you suggest incorporating the transformation functions?
Once we have interfaced with the internal sampling method, we should also be able to use
AbstractMCMC
to do the sample following what Turing does. I already have written something like this for a micro-canonical HMC sampler I have been working on, MicrocanonicalHMC.jl. The most involved step is to overload theAbstractMCMC.step
function with AdvancedHMC equivalent.All the best,
Jaime