Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No glue code #319

Closed
wants to merge 29 commits into from
Closed

No glue code #319

wants to merge 29 commits into from

Conversation

JaimeRZP
Copy link
Member

@JaimeRZP JaimeRZP commented Mar 9, 2023

Hi All!

My name is Jaime Ruiz Zapatero and I am 3rd year PhD in astrophysics at Oxford.

@yebai and I have been discussing possible ways of interfacing Turing and AdvancedHMC in a more general without the current glue code currently present inside Turing.
This motivated by the limitations of the current interface already discussed in this PR by @sethaxen.
The fundamental idea is to extract a LogDensityProblem object from a generic Turing model which then is used to build a Hamiltonian object for AdvancedHMC.
This can be done as follows:

ctxt = model.context
vi = DynamicPPL.VarInfo(model, ctxt)
ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi, model, ctxt))
hamiltonian = AdvancedHMC.Hamiltonian(metric, ℓ)

Where model is a conditioned Turing model.

I have written a Neal's funnel example of how this interface can work which you can find in this notebook. Here I have also drafted some place holder functions to wrap the interface such that it is user-friendly. Essentially, I have added an additional signature to sample and created a wrapper structure for the sampler ingridients. To do so I have tried to follow the idea proposed in this issue by @yebai.

This example is already fully functional. However, it is missing an important aspect. Turing models can have priors with hard boundaries. However, samplers work best in continuous spaces. The solution is map the bounded prior space to the unbounded sampling space. Turing does this internally but to avoid relying too much on Turing one can also do the following:

ctxt = model.context
vi = DynamicPPL.VarInfo(model, ctxt)
vi_t = Turing.link!!(vi, model) # This transforms the variables to the unbounded space

# By passing vi_t as opposed to vi to ℓ, the log density function will input and output the transformed variables= LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi_t, model, ctxt))
hamiltonian = AdvancedHMC.Hamiltonian(metric, ℓ)

Then the generated samples can be transformed back to the prior space using:

function _get_dists(vi)
    mds = values(vi.metadata)
    return [md.dists[1] for md in mds]
end

dists = _get_dists(vi)
dist_lengths = [length(dist) for dist in dists]
vsyms = _name_variables(vi, dist_lengths)

function _reshape_params(x::AbstractVector)
    xx = []
    idx = 0
    for dist_length in dist_lengths
        append!(xx, [x[idx+1:idx+dist_length]])
        idx += dist_length
    end
    return xx
end

function transform(x)
    x = _reshape_params(x)
    xt = [Bijectors.link(dist, par) for (dist, par) in zip(dists, x)]
    return vcat(xt...)
end

function inv_transform(xt)
    xt = _reshape_params(xt)
    x = [Bijectors.invlink(dist, par) for (dist, par) in zip(dists, xt)]
    return vcat(x...)
end

Note that these already account for the jacobian of the transformation. Also, this might not be the most elegant way of doing the transformation.

Question: Where would you suggest writing these functions into the code? My guess would be within sampler but I don't want to mess around with the design idea.

Question: Where would you suggest incorporating the transformation functions?

Once we have interfaced with the internal sampling method, we should also be able to use AbstractMCMC to do the sample following what Turing does. I already have written something like this for a micro-canonical HMC sampler I have been working on, MicrocanonicalHMC.jl. The most involved step is to overload the AbstractMCMC.step function with AdvancedHMC equivalent.

All the best,
Jaime

@JaimeRZP JaimeRZP linked an issue Mar 9, 2023 that may be closed by this pull request
src/sampler.jl Outdated
Comment on lines 176 to 177
function sample(model::DynamicPPL.Model, ϵ::Number, TAP::Number, n_samples::Int, n_adapts::Int;
initial_θ=initial_θ, progress=true, kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function sample(model::DynamicPPL.Model, ϵ::Number, TAP::Number, n_samples::Int, n_adapts::Int;
initial_θ=initial_θ, progress=true, kwargs...)
function sample(
model::DynamicPPL.Model,
ϵ::Number,
TAP::Number,
n_samples::Int,
n_adapts::Int;
initial_θ = initial_θ,
progress = true,
kwargs...,
)

src/sampler.jl Outdated
vsyms = _name_variables(vi, dist_lengths)
d = length(vsyms)

metric = DiagEuclideanMetric(d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric = DiagEuclideanMetric(d)
metric = DiagEuclideanMetric(d)

src/sampler.jl Outdated

metric = DiagEuclideanMetric(d)
integrator = Leapfrog(ϵ)
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)

src/sampler.jl Outdated
Comment on lines 190 to 191
return sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts;
progress=progress, kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts;
progress=progress, kwargs...)
return sample(
model,
metric,
proposal,
initial_θ,
n_samples,
adaptor,
n_adapts;
progress = progress,
kwargs...,
)

vsyms = keys(vi)
names = []
for (vsym, dist_length) in zip(vsyms, dist_lengths)
if dist_length==1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if dist_length==1
if dist_length == 1

name = [vsym]
append!(names, name)
else
name = [DynamicPPL.VarName(Symbol(vsym, i,)) for i in 1:dist_length]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
name = [DynamicPPL.VarName(Symbol(vsym, i,)) for i in 1:dist_length]
name = [DynamicPPL.VarName(Symbol(vsym, i)) for i = 1:dist_length]

else
name = [DynamicPPL.VarName(Symbol(vsym, i,)) for i in 1:dist_length]
append!(names, name)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

end
end
return names
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

@JaimeRZP
Copy link
Member Author

I have now coded the changes into AdvancedHMC src.
Users should be able to sample a Turing model using AdvancedHMC directly by using:

n_samples, n_adapts = 10_000, 1_000
sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts)

or even simpler

samples, stats = sample(model, 0.1, 0.95, n_samples, n_adapts; initial_θ=initial_θ)

which will use NUTS by defautl.

The unbounded to bounded space transforms are still missing.

@JaimeRZP JaimeRZP mentioned this pull request May 30, 2023
5 tasks
d = length(vsyms)

# wrap metric, kernel and adaptor into HMCSampler
metric = DiagEuclideanMetric(d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric = DiagEuclideanMetric(d)
metric = DiagEuclideanMetric(d)

Comment on lines 106 to 107
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(settings.TAP, integrator))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(settings.TAP, integrator))
kernel = AdvancedHMC.NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
adaptor =
StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(settings.TAP, integrator))

@@ -25,6 +25,15 @@ struct HMCSampler{K,M,A} <: AbstractMCMC.AbstractSampler
end
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation())

# Convinience constructor
function NUTSSampler(ϵ::Float64, TAP::Float64, d::Int)
metric = DiagEuclideanMetric(d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric = DiagEuclideanMetric(d)
metric = DiagEuclideanMetric(d)

function NUTSSampler(ϵ::Float64, TAP::Float64, d::Int)
metric = DiagEuclideanMetric(d)
integrator = Leapfrog(ϵ)
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
kernel = AdvancedHMC.NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)

kernel = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator))
return HMCSampler(kernel, metric, adaptor)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

# No glue code #
################
function AbstractMCMC.sample(
model::DynamicPPL.Model,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
model::DynamicPPL.Model,
model::DynamicPPL.Model,

progress = true,
verbose = false,
callback = nothing,
kwargs...,
)
sampler = HMCSampler(kernel, metric, adaptor)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
)
)

ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi, model, ctxt))
d = LogDensityProblems.dimension(ℓ)
model = AbstractMCMC.LogDensityModel(ℓ)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change


if init_params === nothing
init_params = randn(rng, size(metric, 1))
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
)
)

if spl.metric == nothing
metric = DiagEuclideanMetric(d)
else
metric = spl.metric
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric = spl.metric
metric = spl.metric

Comment on lines 148 to 154
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts = n_adapts,
kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts = n_adapts,
kwargs...)
return AbstractMCMC.step(rng, model, spl, state; n_adapts = n_adapts, kwargs...)

Comment on lines 84 to 87
metric
integrator
kernel
adaptor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric
integrator
kernel
adaptor
metric::Any
integrator::Any
kernel::Any
adaptor::Any

Comment on lines 93 to 99
max_depth::Int=10,
Δ_max::Float64=1000.0,
init_ϵ::Float64=0.0,
metric=nothing,
integrator=Leapfrog,
kernel = NUTS_kernel{MultinomialTS, GeneralisedNoUTurn}
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
max_depth::Int=10,
Δ_max::Float64=1000.0,
init_ϵ::Float64=0.0,
metric=nothing,
integrator=Leapfrog,
kernel = NUTS_kernel{MultinomialTS, GeneralisedNoUTurn}
)
max_depth::Int = 10,
Δ_max::Float64 = 1000.0,
init_ϵ::Float64 = 0.0,
metric = nothing,
integrator = Leapfrog,
kernel = NUTS_kernel{MultinomialTS,GeneralisedNoUTurn},
)

Comment on lines 101 to 103
return StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(TAP, integrator))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(TAP, integrator))
end
return StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator))
end

StepSizeAdaptor(TAP, integrator))
end
NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, kernel, adaptor)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end


function NUTS_kernel(integrator)
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Comment on lines 75 to 79
max_depth::Int=10,
Δ_max::Float64=1000.0,
init_ϵ::Float64=0.0,
metric=nothing,
integrator=Leapfrog)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
max_depth::Int=10,
Δ_max::Float64=1000.0,
init_ϵ::Float64=0.0,
metric=nothing,
integrator=Leapfrog)
max_depth::Int = 10,
Δ_max::Float64 = 1000.0,
init_ϵ::Float64 = 0.0,
metric = nothing,
integrator = Leapfrog,
)

NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor)
end

export NUTS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
export NUTS
export NUTS

function kernel(integrator)
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(n_leapfrog)))
end
return kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return kernel
return kernel

Comment on lines 130 to 132
metric
integrator
kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric
integrator
kernel
metric::Any
integrator::Any
kernel::Any

function kernel(integrator)
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedIntegrationTime(λ)))
end
return kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return kernel
return kernel

Comment on lines 182 to 188
n_adapts :: Int # number of samples with adaption for ϵ
TAP :: Float64 # target accept rate
λ :: Float64 # target leapfrog length
ϵ :: Float64 # (initial) step size
metric
integrator
kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
n_adapts :: Int # number of samples with adaption for ϵ
TAP :: Float64 # target accept rate
λ :: Float64 # target leapfrog length
ϵ :: Float64 # (initial) step size
metric
integrator
kernel
n_adapts::Int # number of samples with adaption for ϵ
TAP::Float64 # target accept rate
λ::Float64 # target leapfrog length
ϵ::Float64 # (initial) step size
metric::Any
integrator::Any
kernel::Any

Comment on lines 195 to 197
ϵ::Float64=0.0,
metric=nothing,
integrator=Leapfrog)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
ϵ::Float64=0.0,
metric=nothing,
integrator=Leapfrog)
ϵ::Float64 = 0.0,
metric = nothing,
integrator = Leapfrog,
)

Comment on lines 200 to 202
return StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(TAP, integrator))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(TAP, integrator))
end
return StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator))
end

return HMCDA(n_adapts, TAP, λ, ϵ, metric, integrator, kernel, adaptor)
end

export HMCDA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
export HMCDA
export HMCDA

progress = false # don't use AMCMC's progress-funtionality
)
vi = kwargs[:vi]
d = kwargs[:d]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
d = kwargs[:d]
d = kwargs[:d]

Comment on lines 130 to 133
metric
integrator
kernel
adaptor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric
integrator
kernel
adaptor
metric::Any
integrator::Any
kernel::Any
adaptor::Any

Comment on lines 183 to 190
n_adapts :: Int # number of samples with adaption for ϵ
TAP :: Float64 # target accept rate
λ :: Float64 # target leapfrog length
ϵ :: Float64 # (initial) step size
metric
integrator
kernel
adaptor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
n_adapts :: Int # number of samples with adaption for ϵ
TAP :: Float64 # target accept rate
λ :: Float64 # target leapfrog length
ϵ :: Float64 # (initial) step size
metric
integrator
kernel
adaptor
n_adapts::Int # number of samples with adaption for ϵ
TAP::Float64 # target accept rate
λ::Float64 # target leapfrog length
ϵ::Float64 # (initial) step size
metric::Any
integrator::Any
kernel::Any
adaptor::Any

@@ -246,4 +246,4 @@ function sample(
@info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h κ EBFMI_est average_acceptance_rate
end
return θs, stats
end
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

spl,
state;
n_adapts = n_adapts,
kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
kwargs...)
kwargs...,
)

state::HMCState;
nadapts::Int = 0,
kwargs...,
)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
)
)

getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
function getdimensions(f::AbstractMCMC.LogDensityModel)
return LogDensityProblems.dimension(f.logdensity)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Comment on lines 97 to 104
println(typeof(hamiltonian)<:Hamiltonian)
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts = n_adapts,
kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
println(typeof(hamiltonian)<:Hamiltonian)
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts = n_adapts,
kwargs...)
println(typeof(hamiltonian) <: Hamiltonian)
return AbstractMCMC.step(rng, model, spl, state; n_adapts = n_adapts, kwargs...)

src/AdvancedHMC.jl Show resolved Hide resolved
Comment on lines 81 to 85
return StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(TAP, integrator))
end
AHMC_NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(TAP, integrator))
end
AHMC_NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor)
end
return StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator))
end
AHMC_NUTS(
n_adapts,
TAP,
max_depth,
Δ_max,
init_ϵ,
metric,
integrator,
NUTS_kernel,
adaptor,
)
end

AHMC_NUTS(n_adapts, TAP, max_depth, Δ_max, init_ϵ, metric, integrator, NUTS_kernel, adaptor)
end

export AHMC_NUTS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
export AHMC_NUTS
export AHMC_NUTS

return AHMC_HMCDA(n_adapts, TAP, λ, ϵ, metric, integrator, kernel, adaptor)
end

export AHMC_HMCDA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
export AHMC_HMCDA
export AHMC_HMCDA

Comment on lines 83 to 84
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.alg.TAP, integrator))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.alg.TAP, integrator))
adaptor = StanHMCAdaptor(
MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.alg.TAP, integrator),
)

else
adaptor = spl.adaptor
n_adapts = kwargs[:n_adapts]
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

# Basic use
HMCSampler(algorithm) = HMCSampler(algorithm, nothing, nothing, nothing, nothing)
# Expert use
HMCSampler(integrator, kernel, metric, adaptor) = HMCSampler(Custom_alg, integrator, kernel, metric, adaptor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
HMCSampler(integrator, kernel, metric, adaptor) = HMCSampler(Custom_alg, integrator, kernel, metric, adaptor)
HMCSampler(integrator, kernel, metric, adaptor) =
HMCSampler(Custom_alg, integrator, kernel, metric, adaptor)

##########
# Custom #
##########
struct Custom_alg<:SamplingAlgorithm end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
struct Custom_alg<:SamplingAlgorithm end
struct Custom_alg <: SamplingAlgorithm end

Comment on lines 78 to 80
max_depth::Int=10,
Δ_max::Float64=1000.0,
ϵ::Float64=0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
max_depth::Int=10,
Δ_max::Float64=1000.0,
ϵ::Float64=0.0)
max_depth::Int = 10,
Δ_max::Float64 = 1000.0,
ϵ::Float64 = 0.0,
)

Comment on lines 161 to 165
function HMCDA(
n_adapts::Int,
TAP::Float64,
λ::Float64;
ϵ::Float64=0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function HMCDA(
n_adapts::Int,
TAP::Float64,
λ::Float64;
ϵ::Float64=0.0)
function HMCDA(n_adapts::Int, TAP::Float64, λ::Float64; ϵ::Float64 = 0.0)

Comment on lines 174 to 176
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator),
StepSizeAdaptor(alg.TAP, integrator))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator),
StepSizeAdaptor(alg.TAP, integrator))
end
return StanHMCAdaptor(
MassMatrixAdaptor(metric, integrator),
StepSizeAdaptor(alg.TAP, integrator),
)
end


function make_kernel(alg::NUTS_alg, integrator)
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end


function make_kernel(alg::HMC_alg, integrator)
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(alg.n_leapfrog)))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end


function make_kernel(alg::HMCDA_alg, integrator)
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedIntegrationTime(alg.λ)))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Comment on lines 108 to 114
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts=n_adapts,
kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return AbstractMCMC.step(
rng,
model,
spl,
state;
n_adapts=n_adapts,
kwargs...)
return AbstractMCMC.step(rng, model, spl, state; n_adapts = n_adapts, kwargs...)

state::HMCState;
nadapts::Int = 0,
nadapts::Int=0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
nadapts::Int=0,
nadapts::Int = 0,

@@ -217,14 +143,19 @@ function AbstractMCMC.step(
h, κ, isadapted = adapt!(h, κ, adaptor, i, nadapts, t.z.θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt = isadapted,))

# Convert variables back
vii_t = DynamicPPL.unflatten(vi_t, t.z.θ)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
vii_t = DynamicPPL.unflatten(vi_t, t.z.θ)
vii_t = DynamicPPL.unflatten(vi_t, t.z.θ)

Comment on lines 43 to 47
vi = DynamicPPL.VarInfo(model, ctxt)
vi_t = DynamicPPL.link!!(vi, model)
logdensityfunction = DynamicPPL.LogDensityFunction(vi_t, model, ctxt)
logdensityproblem = LogDensityProblemsAD.ADgradient(logdensityfunction)
logdensitymodel = AbstractMCMC.LogDensityModel(logdensityproblem)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, we should be independent of DynamicPPL APIs or internals from MCMC libraries. Maybe this should be transferred to DynamicPPL/AbstractMCMC.

Comment on lines 182 to 183
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.alg.δ, integrator))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.alg.δ, integrator))
adaptor = StanHMCAdaptor(
MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.alg.δ, integrator),
)

state = HMCState(0, t, h.metric, κ, adaptor)

state = HMCState(0, t, metric, κ, adaptor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@@ -302,4 +312,4 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw
elseif verbose && isadapted && i == nadapts
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric
end
end
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Comment on lines 155 to 158
n_adapts :: Int # number of samples with adaption for ϵ
δ :: Float64 # target accept rate
λ :: Float64 # target leapfrog length
ϵ :: Float64 # (initial) step size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
n_adapts :: Int # number of samples with adaption for ϵ
δ :: Float64 # target accept rate
λ :: Float64 # target leapfrog length
ϵ :: Float64 # (initial) step size
n_adapts::Int # number of samples with adaption for ϵ
δ::Float64 # target accept rate
λ::Float64 # target leapfrog length
ϵ::Float64 # (initial) step size

Comment on lines 161 to 165
function HMCDA(
n_adapts::Int,
δ::Float64,
λ::Float64;
ϵ::Float64=0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function HMCDA(
n_adapts::Int,
δ::Float64,
λ::Float64;
ϵ::Float64=0.0)
function HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64 = 0.0)

Comment on lines 174 to 176
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator),
StepSizeAdaptor(alg.δ, integrator))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return StanHMCAdaptor(MassMatrixAdaptor(metric, integrator),
StepSizeAdaptor(alg.δ, integrator))
end
return StanHMCAdaptor(
MassMatrixAdaptor(metric, integrator),
StepSizeAdaptor(alg.δ, integrator),
)
end

Base.@kwdef struct HMCSampler{I,K,M,A} <: AbstractMCMC.AbstractSampler
alg::SamplingAlgorithm
"[`integrator`](@ref)."
integrator::I=nothing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
integrator::I=nothing
integrator::I = nothing

"[`integrator`](@ref)."
integrator::I=nothing
"[`AbstractMCMCKernel`](@ref)."
kernel::K=nothing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
kernel::K=nothing
kernel::K = nothing

"[`AbstractMCMCKernel`](@ref)."
kernel::K=nothing
"[`AbstractMetric`](@ref)."
metric::M=nothing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
metric::M=nothing
metric::M = nothing

"[`AbstractMetric`](@ref)."
metric::M=nothing
"[`AbstractAdaptor`](@ref)."
adaptor::A=nothing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
adaptor::A=nothing
adaptor::A = nothing

Comment on lines 160 to 164
function HMCDA(
n_adapts::Int,
δ::Float64,
λ::Float64;
ϵ::Float64=0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function HMCDA(
n_adapts::Int,
δ::Float64,
λ::Float64;
ϵ::Float64=0.0)
function HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64 = 0.0)

To access the updated fields use the resulting [`HMCState`](@ref).
"""
Base.@kwdef struct HMCSampler{I,K,M,A} <: AbstractMCMC.AbstractSampler
alg::HMCAlgorithm=Custom_alg
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
alg::HMCAlgorithm=Custom_alg
alg::HMCAlgorithm = Custom_alg

##########
# Custom #
##########
struct Custom_alg<:HMCAlgorithm end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
struct Custom_alg<:HMCAlgorithm end
struct Custom_alg <: HMCAlgorithm end

Comment on lines 74 to 78
max_depth::Int=10,
Δ_max::Float64=1000.0,
ϵ::Float64=0.0)
return HMCSampler(;alg=NUTS_alg(n_adapts, δ, max_depth, Δ_max, ϵ))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
max_depth::Int=10,
Δ_max::Float64=1000.0,
ϵ::Float64=0.0)
return HMCSampler(;alg=NUTS_alg(n_adapts, δ, max_depth, Δ_max, ϵ))
end
max_depth::Int = 10,
Δ_max::Float64 = 1000.0,
ϵ::Float64 = 0.0,
)
return HMCSampler(; alg = NUTS_alg(n_adapts, δ, max_depth, Δ_max, ϵ))
end

ϵ::Float64,
n_leapfrog::Int)

return HMCSampler(;alg=HMC_alg(ϵ, n_leapfrog))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return HMCSampler(;alg=HMC_alg(ϵ, n_leapfrog))
return HMCSampler(; alg = HMC_alg(ϵ, n_leapfrog))

Comment on lines 157 to 162
function HMCDA(
n_adapts::Int,
δ::Float64,
λ::Float64;
ϵ::Float64=0.0)
return HMCSampler(;alg=HMCDA_alg(n_adapts, δ, λ, ϵ))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function HMCDA(
n_adapts::Int,
δ::Float64,
λ::Float64;
ϵ::Float64=0.0)
return HMCSampler(;alg=HMCDA_alg(n_adapts, δ, λ, ϵ))
function HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64 = 0.0)
return HMCSampler(; alg = HMCDA_alg(n_adapts, δ, λ, ϵ))

Comment on lines 41 to 42
d = d=LogDensityProblems.dimension(logdensity)
metric = make_metric(spl; d=d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
d = d=LogDensityProblems.dimension(logdensity)
metric = make_metric(spl; d=d)
d = d = LogDensityProblems.dimension(logdensity)
metric = make_metric(spl; d = d)

Comment on lines 49 to 52
integrator = make_integrator(spl;
rng=rng,
hamiltonian=hamiltonian,
init_params=init_params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
integrator = make_integrator(spl;
rng=rng,
hamiltonian=hamiltonian,
init_params=init_params)
integrator = make_integrator(
spl;
rng = rng,
hamiltonian = hamiltonian,
init_params = init_params,
)

@@ -0,0 +1,204 @@
abstract type AbstractHMCSampler <:AbstractMCMC.AbstractSampler end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
abstract type AbstractHMCSampler <:AbstractMCMC.AbstractSampler end
abstract type AbstractHMCSampler <: AbstractMCMC.AbstractSampler end

"""
Base.@kwdef struct CustomHMC{I,K,M,A} <: AbstractMCMC.AbstractSampler
"[`integrator`](@ref)."
integrator::I=Leapfrog
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
integrator::I=Leapfrog
integrator::I = Leapfrog

Comment on lines +61 to +65
max_depth::Int=10 # maximum tree depth
Δ_max::Float64=1000.0 # maximum error
init_ϵ::Float64=0.0 # (initial) step size
integrator_method=Leapfrog # integrator method
metric_type=DiagEuclideanMetric # metric type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
max_depth::Int=10 # maximum tree depth
Δ_max::Float64=1000.0 # maximum error
init_ϵ::Float64=0.0 # (initial) step size
integrator_method=Leapfrog # integrator method
metric_type=DiagEuclideanMetric # metric type
max_depth::Int = 10 # maximum tree depth
Δ_max::Float64 = 1000.0 # maximum error
init_ϵ::Float64 = 0.0 # (initial) step size
integrator_method = Leapfrog # integrator method
metric_type = DiagEuclideanMetric # metric type

Comment on lines +173 to +175
function make_adaptor(spl::Union{NUTS_alg, HMCDA_alg}, metric, integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.δ, integrator))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function make_adaptor(spl::Union{NUTS_alg, HMCDA_alg}, metric, integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(spl.δ, integrator))
function make_adaptor(spl::Union{NUTS_alg,HMCDA_alg}, metric, integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(spl.δ, integrator))

StepSizeAdaptor(spl.δ, integrator))
n_adapts = spl.n_adapts
return n_adapts, adaptor
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end


function make_adaptor(spl::HMC_alg, metric, integrator)
return 0, NoAdaptation()
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

return 0, NoAdaptation()
end

function make_adaptor(spl::CustomHMC, metric, integrator)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function make_adaptor(spl::CustomHMC, metric, integrator)
function make_adaptor(spl::CustomHMC, metric, integrator)


function make_adaptor(spl::CustomHMC, metric, integrator)
return spl.n_adapts, spl.adaptor
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Comment on lines +147 to +148
function make_integrator(rng, spl::Union{HMC_alg, NUTS_alg, HMCDA_alg},
hamiltonian, init_params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function make_integrator(rng, spl::Union{HMC_alg, NUTS_alg, HMCDA_alg},
hamiltonian, init_params)
function make_integrator(
rng,
spl::Union{HMC_alg,NUTS_alg,HMCDA_alg},
hamiltonian,
init_params,
)


#########

function make_metric(spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, logdensity)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function make_metric(spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, logdensity)
function make_metric(spl::Union{HMC_alg,NUTS_alg,HMCDA_alg}, logdensity)

@JaimeRZP JaimeRZP closed this Jun 27, 2023
@JaimeRZP JaimeRZP deleted the no_glue_code branch June 27, 2023 15:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

More friendly default sample interface
2 participants