Skip to content

Commit

Permalink
renamed _getmodel to getmodel, _setmodel to setmodel, and
Browse files Browse the repository at this point in the history
`_varinfo` to `varinfo_from_logdensityfn`
  • Loading branch information
torfjelde committed Jun 26, 2024
1 parent 414a077 commit 89bc2e1
Showing 1 changed file with 37 additions and 20 deletions.
57 changes: 37 additions & 20 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,41 @@ function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transit
return transition_to_turing(parent(f), transition)
end

_getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = _getmodel(parent(f))
_getmodel(f::DynamicPPL.LogDensityFunction) = f.model
"""
getmodel(f)
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

# FIXME: We'll have to overload this for every AD backend since some of the AD backends
# will cache certain parts of a given model, e.g. the tape, which results in a discrepancy
# between the primal (forward) and dual (backward).
function _setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model)
return Accessors.@set f.= _setmodel(f.ℓ, model)
"""
setmodel(f, model)
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
!!! warning
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model)
return Accessors.@set f.= setmodel(f.ℓ, model)
end
function _setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model
end

_varinfo(f::LogDensityProblemsAD.ADGradientWrapper) = _varinfo(parent(f))
function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
return varinfo_from_logdensityfn(parent(f))
end
_varinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo

function varinfo(state::TuringState)
θ = getparams(_getmodel(state.logdensity), state.state)
θ = getparams(getmodel(state.logdensity), state.state)
# TODO: Do we need to link here first?
return DynamicPPL.unflatten(_varinfo(state.logdensity), θ)
end
Expand Down Expand Up @@ -67,17 +84,14 @@ function recompute_logprob!!(
rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:ExternalSampler},
state
state,
)
# Re-using the log-density function from the `state` and updating only the `model` field,
# since the `model` might now contain different conditioning values.
f = _setmodel(state.logdensity, model)
f = setmodel(state.logdensity, model)
# Recompute the log-probability with the new `model`.
state_inner = recompute_logprob!!(
rng,
AbstractMCMC.LogDensityModel(f),
sampler.alg.sampler,
state.state
rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state
)
return state_to_turing(f, state_inner)
end
Expand All @@ -86,15 +100,13 @@ function recompute_logprob!!(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::AdvancedHMC.AbstractHMCSampler,
state::AdvancedHMC.HMCState
state::AdvancedHMC.HMCState,
)
# Construct hamiltionian.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
# Re-compute the log-probability and gradient.
return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian,
state.transition.z.θ,
state.transition.z.r,
hamiltonian, state.transition.z.θ, state.transition.z.r
)
end

Expand All @@ -115,7 +127,7 @@ function AbstractMCMC.step(
sampler_wrapper::Sampler{<:ExternalSampler};
initial_state=nothing,
initial_params=nothing,
kwargs...
kwargs...,
)
alg = sampler_wrapper.alg
sampler = alg.sampler
Expand Down Expand Up @@ -145,7 +157,12 @@ function AbstractMCMC.step(
)
else
transition_inner, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler, initial_state; initial_params, kwargs...
rng,
AbstractMCMC.LogDensityModel(f),
sampler,
initial_state;
initial_params,
kwargs...,
)
end
# Update the `state`
Expand All @@ -157,7 +174,7 @@ function AbstractMCMC.step(
model::DynamicPPL.Model,
sampler_wrapper::Sampler{<:ExternalSampler},
state::TuringState;
kwargs...
kwargs...,
)
sampler = sampler_wrapper.alg.sampler
f = state.logdensity
Expand Down

0 comments on commit 89bc2e1

Please sign in to comment.