-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No glue code #319
No glue code #319
Changes from 6 commits
43cfe65
3158bb6
37d6831
6fe6436
304d401
46f0803
2f6f2c1
dfd5e74
3038441
deb0555
00b837a
ce96cac
1f8c5a7
612e10b
3bbc668
1bffe99
b941529
8b1f962
0f45cc8
3e1c403
8fa9fcb
c582abf
39941fd
3684a1e
62c2096
1893e3f
1af622c
acaa289
80f0d8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -131,6 +131,68 @@ sample( | |||||||||||||||||||||||||||
(pm_next!) = pm_next!, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
### | ||||||||||||||||||||||||||||
# Allows to pass Turing model to build Hamiltonian | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
function sample( | ||||||||||||||||||||||||||||
model::DynamicPPL.Model, | ||||||||||||||||||||||||||||
metric::AbstractMetric, | ||||||||||||||||||||||||||||
κ::AbstractMCMCKernel, | ||||||||||||||||||||||||||||
θ::AbstractVecOrMat{<:AbstractFloat}, | ||||||||||||||||||||||||||||
n_samples::Int, | ||||||||||||||||||||||||||||
adaptor::AbstractAdaptor = NoAdaptation(), | ||||||||||||||||||||||||||||
n_adapts::Int = min(div(n_samples, 10), 1_000); | ||||||||||||||||||||||||||||
drop_warmup = false, | ||||||||||||||||||||||||||||
verbose::Bool = true, | ||||||||||||||||||||||||||||
progress::Bool = false, | ||||||||||||||||||||||||||||
(pm_next!)::Function = pm_next!, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
ctxt = model.context | ||||||||||||||||||||||||||||
vi = DynamicPPL.VarInfo(model, ctxt) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# We will need to implement this but it is going to be | ||||||||||||||||||||||||||||
# Interesting how to plug the transforms along the sampling | ||||||||||||||||||||||||||||
# processes | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
#vi_t = Turing.link!!(vi, model) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(vi, model, ctxt)) | ||||||||||||||||||||||||||||
h = Hamiltonian(metric, ℓ) | ||||||||||||||||||||||||||||
return sample( | ||||||||||||||||||||||||||||
GLOBAL_RNG, | ||||||||||||||||||||||||||||
h, | ||||||||||||||||||||||||||||
κ, | ||||||||||||||||||||||||||||
θ, | ||||||||||||||||||||||||||||
n_samples, | ||||||||||||||||||||||||||||
adaptor, | ||||||||||||||||||||||||||||
n_adapts; | ||||||||||||||||||||||||||||
drop_warmup = drop_warmup, | ||||||||||||||||||||||||||||
verbose = verbose, | ||||||||||||||||||||||||||||
progress = progress, | ||||||||||||||||||||||||||||
(pm_next!) = pm_next!, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
function sample(model::DynamicPPL.Model, ϵ::Number, TAP::Number, n_samples::Int, n_adapts::Int; | ||||||||||||||||||||||||||||
initial_θ=initial_θ, progress=true, kwargs...) | ||||||||||||||||||||||||||||
ctxt = model.context | ||||||||||||||||||||||||||||
vi = VarInfo(model, ctxt) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
dists = _get_dists(vi) | ||||||||||||||||||||||||||||
dist_lengths = [length(dist) for dist in dists] | ||||||||||||||||||||||||||||
vsyms = _name_variables(vi, dist_lengths) | ||||||||||||||||||||||||||||
d = length(vsyms) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
metric = DiagEuclideanMetric(d) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||
integrator = Leapfrog(ϵ) | ||||||||||||||||||||||||||||
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(TAP, integrator)) | ||||||||||||||||||||||||||||
return sample(model, metric, proposal, initial_θ, n_samples, adaptor, n_adapts; | ||||||||||||||||||||||||||||
progress=progress, kwargs...) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
### | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||
sample( | ||||||||||||||||||||||||||||
rng::AbstractRNG, | ||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,19 @@ | ||||||
function _get_dists(vi::VarInfo) | ||||||
mds = values(vi.metadata) | ||||||
return [md.dists[1] for md in mds] | ||||||
end | ||||||
|
||||||
function _name_variables(vi::VarInfo, dist_lengths::AbstractVector) | ||||||
vsyms = keys(vi) | ||||||
names = [] | ||||||
for (vsym, dist_length) in zip(vsyms, dist_lengths) | ||||||
if dist_length==1 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
name = [vsym] | ||||||
append!(names, name) | ||||||
else | ||||||
name = [DynamicPPL.VarName(Symbol(vsym, i,)) for i in 1:dist_length] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
append!(names, name) | ||||||
end | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
end | ||||||
return names | ||||||
end | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶