From 5948253f792aeec0e7418a438a5afddb0cbf0c4e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 23 Sep 2024 14:28:08 +0100 Subject: [PATCH 01/70] Replace old Gibbs sampler with the experimental one. --- src/Turing.jl | 1 - src/experimental/Experimental.jl | 16 - src/experimental/gibbs.jl | 488 -------------------------- src/mcmc/Inference.jl | 7 +- src/mcmc/abstractmcmc.jl | 4 + src/mcmc/gibbs.jl | 580 ++++++++++++++++++++++--------- src/mcmc/gibbs_conditional.jl | 88 ----- test/experimental/gibbs.jl | 270 -------------- test/mcmc/Inference.jl | 14 +- test/mcmc/ess.jl | 10 +- test/mcmc/gibbs.jl | 325 +++++++++++++++-- test/mcmc/gibbs_conditional.jl | 172 --------- test/mcmc/hmc.jl | 12 +- test/mcmc/mh.jl | 10 +- test/runtests.jl | 5 - test/skipped/explicit_ret.jl | 2 +- 16 files changed, 751 insertions(+), 1253 deletions(-) delete mode 100644 src/experimental/Experimental.jl delete mode 100644 src/experimental/gibbs.jl delete mode 100644 src/mcmc/gibbs_conditional.jl delete mode 100644 test/experimental/gibbs.jl delete mode 100644 test/mcmc/gibbs_conditional.jl diff --git a/src/Turing.jl b/src/Turing.jl index 8dfb8df286..8fcee6c185 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -86,7 +86,6 @@ export @model, # modelling Emcee, ESS, Gibbs, - GibbsConditional, HMC, # Hamiltonian-like sampling SGLD, SGHMC, diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl deleted file mode 100644 index 518538e6c3..0000000000 --- a/src/experimental/Experimental.jl +++ /dev/null @@ -1,16 +0,0 @@ -module Experimental - -using Random: Random -using AbstractMCMC: AbstractMCMC -using DynamicPPL: DynamicPPL, VarName -using Accessors: Accessors - -using DocStringExtensions: TYPEDFIELDS -using Distributions - -using ..Turing: Turing -using ..Turing.Inference: gibbs_rerun, InferenceAlgorithm - -include("gibbs.jl") - -end diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl deleted file mode 100644 index 596e6e283b..0000000000 --- a/src/experimental/gibbs.jl +++ /dev/null @@ -1,488 +0,0 @@ -# Basically like a `DynamicPPL.FixedContext` but -# 1. Hijacks the tilde pipeline to fix variables. -# 2. Computes the log-probability of the fixed variables. -# -# Purpose: avoid triggering resampling of variables we're conditioning on. -# - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. -# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to -# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable -# rather than only for the "true" observations. -# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline -# rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values - context::Ctx -end - -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) - -DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::GibbsContext) = context.context -DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) - -# has and get -has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) -function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) -end - -get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) -function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(get_conditioned_gibbs, context), vns) -end - -# Tilde pipeline -function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) -end - -function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) -end - -# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. -make_broadcastable(x) = x -make_broadcastable(dist::Distribution) = tuple(dist) - -# Need the following two methods to properly support broadcasting over columns. -broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) -function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) - return loglikelihood(dist, x) -end - -# Needed to support broadcasting over columns for `MultivariateDistribution`s. -reconstruct_getvalue(dist, x) = x -function reconstruct_getvalue( - dist::MultivariateDistribution, - x::AbstractVector{<:AbstractVector{<:Real}} -) - return reduce(hcat, x[2:end]; init=x[1]) -end - -function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi -) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) -end - - -""" - preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) - -Returns the preferred value type for a variable with the given `varinfo`. -""" -preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict -preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple -function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) - # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. - namedtuple_compatible = all(varinfo.metadata) do md - eltype(md.vns) <: VarName{<:Any,typeof(identity)} - end - return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict -end - -""" - condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) - -Return a `GibbsContext` with the given values treated as conditioned. - -# Arguments -- `context::DynamicPPL.AbstractContext`: The context to condition. -- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. - If multiple values are provided, we recursively condition on each of them. -""" -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. -function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs( - condition_gibbs(context, value), - values... - ) -end - -# For `DynamicPPL.AbstractVarInfo` we just extract the values. -""" - condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) - -Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. -""" -function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) - return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) -end -function condition_gibbs( - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo, - varinfos::DynamicPPL.AbstractVarInfo... -) - return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) -end -# Allow calling this on a `DynamicPPL.Model` directly. -function condition_gibbs(model::DynamicPPL.Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end - - -""" - make_conditional_model(model, varinfo, varinfos) - -Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. - -# Examples -```julia-repl -julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); - -julia> # A separate varinfo for each variable in `model`. - varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); - -julia> # The varinfo we want to NOT condition on. - target_varinfo = first(varinfos); - -julia> # Results in a model with only `m` conditioned. - conditioned_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos); - -julia> result = conditioned_model(); - -julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` -true - -julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` -true -``` -""" -function make_conditional(model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos) - # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return condition_gibbs( - model, - filter(Base.Fix1(!==, target_varinfo), varinfos)... - ) -end -# Assumes the ones given are the ones to condition on. -function make_conditional(model::DynamicPPL.Model, varinfos) - return condition_gibbs( - model, - varinfos... - ) -end - -# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` -# or an `AbstractInferenceAlgorithm`. -wrap_algorithm_maybe(x) = x -wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) - -""" - Gibbs - -A type representing a Gibbs sampler. - -# Fields -$(TYPEDFIELDS) -""" -struct Gibbs{V,A} <: InferenceAlgorithm - "varnames representing variables for each sampler" - varnames::V - "samplers for each entry in `varnames`" - samplers::A -end - -# NamedTuple -Gibbs(; algs...) = Gibbs(NamedTuple(algs)) -function Gibbs(algs::NamedTuple) - return Gibbs( - map(s -> VarName{s}(), keys(algs)), - map(wrap_algorithm_maybe, values(algs)), - ) -end - -# AbstractDict -function Gibbs(algs::AbstractDict) - return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) -end -function Gibbs(algs::Pair...) - return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) -end - -# TODO: Remove when no longer needed. -DynamicPPL.getspace(::Gibbs) = () - -struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} - vi::V - states::S -end - -_maybevec(x) = vec(x) # assume it's iterable -_maybevec(x::Tuple) = [x...] -_maybevec(x::VarName) = [x] - -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - vi_base::DynamicPPL.AbstractVarInfo; - initial_params=nothing, - kwargs..., -) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers - - # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(model) - - # Simple way of setting the initial parameters: set them in the `vi_base` - # if they are given so they propagate to the subset varinfos used by each sampler. - if initial_params !== nothing - vi_base = DynamicPPL.unflatten(vi_base, initial_params) - end - - # Create the varinfos for each sampler. - varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) - initial_params_all = if initial_params === nothing - fill(nothing, length(varnames)) - else - # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], varinfos) - end - - # 2. Construct a varinfo for every vn + sampler combo. - states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local - # Construct the conditional model. - model_local = make_conditional(model, varinfo_local, varinfos) - - # Take initial step. - new_state_local = last(AbstractMCMC.step( - rng, model_local, sampler_local; - # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. - # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. - initial_params=initial_params_local, - kwargs... - )) - - # Return the new state and the invlinked `varinfo`. - vi_local_state = Turing.Inference.varinfo(new_state_local) - vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) - DynamicPPL.invlink(vi_local_state, sampler_local, model_local) - else - vi_local_state - end - return (new_state_local, vi_local_state_linked) - end - - states = map(first, states_and_varinfos) - varinfos = map(last, states_and_varinfos) - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - state::GibbsState; - kwargs..., -) - alg = spl.alg - samplers = alg.samplers - states = state.states - varinfos = map(Turing.Inference.varinfo, state.states) - @assert length(samplers) == length(state.states) - - # TODO: move this into a recursive function so we can unroll when reasonable? - for index = 1:length(samplers) - # Take the inner step. - new_state_local, new_varinfo_local = gibbs_step_inner( - rng, - model, - samplers, - states, - varinfos, - index; - kwargs..., - ) - - # Update the `states` and `varinfos`. - states = Accessors.setindex(states, new_state_local, index) - varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) - end - - # Combine the resulting varinfo objects. - # The last varinfo holds the correctly computed logp. - vi_base = state.vi - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!( - varinfos, - merge(vi_base, first(varinfos)), - firstindex(varinfos), - ) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. -function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) - # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide - # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact - # same `selector` as before but now with `rerun` set to `true` if needed. - return Accessors.@set sampler.selector.rerun = true -end - -# Interface we need a sampler to implement to work as a component in a Gibbs sampler. -""" - gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - -Check if the log-probability of the destination model needs to be recomputed. - -Defaults to `true` -""" -function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - return true -end - -# TODO: Remove `rng`? -function Turing.Inference.recompute_logprob!!( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler, - state -) - varinfo = Turing.Inference.varinfo(state) - # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, - # e.g. log-likelihood in the scenario of `ESS`. - # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. - sampler_rerun = make_rerun_sampler(model, sampler) - # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed - # `varinfo`, even if `varinfo` was linked. - varinfo_new = last(DynamicPPL.evaluate!!( - model, - varinfo, - # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. - DynamicPPL.SamplingContext(rng, sampler_rerun) - )) - # Update the state we're about to use if need be. - # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - return Turing.Inference.gibbs_state(model, sampler, state, varinfo_new) -end - -function gibbs_step_inner( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - samplers, - states, - varinfos, - index; - kwargs..., -) - # Needs to do a a few things. - sampler_local = samplers[index] - state_local = states[index] - varinfo_local = varinfos[index] - - # Make sure that all `varinfos` are linked. - varinfos_invlinked = map(varinfos) do vi - # NOTE: This is immutable linking! - # TODO: Do we need the `istrans` check here or should we just always use `invlink`? - # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 - DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - end - varinfo_local_invlinked = varinfos_invlinked[index] - - # 1. Create conditional model. - # Construct the conditional model. - # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, - # otherwise we're conditioning on values which are not in the support of the - # distributions. - model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) - - # Extract the previous sampler and state. - sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] - state_previous = states[index == 1 ? length(states) : index - 1] - - # 1. Re-run the sampler if needed. - if gibbs_requires_recompute_logprob( - model_local, - sampler_local, - sampler_previous, - state_local, - state_previous - ) - state_local = Turing.Inference.recompute_logprob!!( - rng, - model_local, - sampler_local, - state_local, - ) - end - - # 2. Take step with local sampler. - new_state_local = last( - AbstractMCMC.step( - rng, - model_local, - sampler_local, - state_local; - kwargs..., - ), - ) - - # 3. Extract the new varinfo. - # Return the resulting state and invlinked `varinfo`. - varinfo_local_state = Turing.Inference.varinfo(new_state_local) - varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) - DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) - else - varinfo_local_state - end - - # TODO: alternatively, we can return `states_new, varinfos_new, index_new` - return (new_state_local, varinfo_local_state_invlinked) -end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index b7bdf206b9..4955598711 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -46,7 +46,6 @@ export InferenceAlgorithm, ESS, Emcee, Gibbs, # classic sampling - GibbsConditional, HMC, SGLD, PolynomialStepsize, @@ -63,7 +62,6 @@ export InferenceAlgorithm, observe, dot_observe, predict, - isgibbscomponent, externalsampler ####################### @@ -526,22 +524,21 @@ end # Concrete algorithm implementations. # ####################################### +include("abstractmcmc.jl") include("ess.jl") include("hmc.jl") include("mh.jl") include("is.jl") include("particle_mcmc.jl") -include("gibbs_conditional.jl") include("gibbs.jl") include("sghmc.jl") include("emcee.jl") -include("abstractmcmc.jl") ################ # Typing tools # ################ -for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs, :Emcee) +for alg in (:SMC, :PG, :MH, :IS, :ESS, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index a350d2908f..965c797060 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -27,6 +27,10 @@ function varinfo(state::TuringState) # TODO: Do we need to link here first? return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ) end +varinfo(state::AbstractVarInfo) = state +# TODO(mhauru) Could we have a type bound on the argument below, for documentation purposes? +varinfo(state) = state.vi + # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 736845b678..fb05b64757 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -1,101 +1,220 @@ -### -### Gibbs samplers / compositional samplers. -### +# Basically like a `DynamicPPL.FixedContext` but +# 1. Hijacks the tilde pipeline to fix variables. +# 2. Computes the log-probability of the fixed variables. +# +# Purpose: avoid triggering resampling of variables we're conditioning on. +# - Using standard `DynamicPPL.condition` results in conditioned variables being treated +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to +# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable +# rather than only for the "true" observations. +# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline +# rather than the `observe` pipeline for the conditioned variables. +struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + values::Values + context::Ctx +end -""" - isgibbscomponent(alg) +Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) -Determine whether algorithm `alg` is allowed as a Gibbs component. -""" -isgibbscomponent(alg) = false +DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::GibbsContext) = context.context +function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) + return GibbsContext(context.values, childcontext) +end -isgibbscomponent(::ESS) = true -isgibbscomponent(::GibbsConditional) = true -isgibbscomponent(::Hamiltonian) = true -isgibbscomponent(::MH) = true -isgibbscomponent(::PG) = true +# has and get +function has_conditioned_gibbs(context::GibbsContext, vn::VarName) + return DynamicPPL.hasvalue(context.values, vn) +end +function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(has_conditioned_gibbs, context), vns) +end -const TGIBBS = Union{InferenceAlgorithm,GibbsConditional} +function get_conditioned_gibbs(context::GibbsContext, vn::VarName) + return DynamicPPL.getvalue(context.values, vn) +end +function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(get_conditioned_gibbs, context), vns) +end -""" - Gibbs(algs...) +# Tilde pipeline +function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end -Compositional MCMC interface. Gibbs sampling combines one or more -sampling algorithms, each of which samples from a different set of -variables in a model. + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) +end -Example: -```julia -@model function gibbs_example(x) - v1 ~ Normal(0,1) - v2 ~ Categorical(5) +function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) end -# Use PG for a 'v2' variable, and use HMC for the 'v1' variable. -# Note that v2 is discrete, so the PG sampler is more appropriate -# than is HMC. -alg = Gibbs(HMC(0.2, 3, :v1), PG(20, :v2)) -``` +# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. +make_broadcastable(x) = x +make_broadcastable(dist::Distribution) = tuple(dist) -One can also pass the number of iterations for each Gibbs component using the following syntax: -- `alg = Gibbs((HMC(0.2, 3, :v1), n_hmc), (PG(20, :v2), n_pg))` -where `n_hmc` and `n_pg` are the number of HMC and PG iterations for each Gibbs iteration. +# Need the following two methods to properly support broadcasting over columns. +broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) +function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) + return loglikelihood(dist, x) +end -Tips: -- `HMC` and `NUTS` are fast samplers and can throw off particle-based -methods like Particle Gibbs. You can increase the effectiveness of particle sampling by including -more particles in the particle sampler. -""" -struct Gibbs{space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} <: InferenceAlgorithm - algs::A # component sampling algorithms - iterations::B - function Gibbs{space,N,A,B}( - algs::A, iterations::B - ) where {space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} - all(isgibbscomponent, algs) || - error("all algorithms have to support Gibbs sampling") - return new{space,N,A,B}(algs, iterations) +# Needed to support broadcasting over columns for `MultivariateDistribution`s. +reconstruct_getvalue(dist, x) = x +function reconstruct_getvalue( + dist::MultivariateDistribution, x::AbstractVector{<:AbstractVector{<:Real}} +) + return reduce(hcat, x[2:end]; init=x[1]) +end + +function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, value), vi end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), right, left, vns, vi + ) end -function Gibbs(alg1::TGIBBS, algrest::Vararg{TGIBBS,N}) where {N} - algs = (alg1, algrest...) - iterations = ntuple(Returns(1), Val(N + 1)) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +function DynamicPPL.dot_tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi + ) end -function Gibbs(arg1::Tuple{<:TGIBBS,Int}, argrest::Vararg{Tuple{<:TGIBBS,Int},N}) where {N} - allargs = (arg1, argrest...) - algs = map(first, allargs) - iterations = map(last, allargs) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +""" + preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) + +Returns the preferred value type for a variable with the given `varinfo`. +""" +preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict +preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple +function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) + # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. + namedtuple_compatible = all(varinfo.metadata) do md + eltype(md.vns) <: VarName{<:Any,typeof(identity)} + end + return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict end """ - GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} + condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) + +Return a `GibbsContext` with the given values treated as conditioned. -Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that -the `Gibbs` sampler iterates through for each `step!`. +# Arguments +- `context::DynamicPPL.AbstractContext`: The context to condition. +- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. + If multiple values are provided, we recursively condition on each of them. """ -struct GibbsState{V<:VarInfo,S<:Tuple{Vararg{Sampler}},T} - vi::V - samplers::S - states::T +condition_gibbs(context::DynamicPPL.AbstractContext) = context +# For `NamedTuple` and `AbstractDict` we just construct the context. +function condition_gibbs( + context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict} +) + return GibbsContext(values, context) +end +# If we get more than one argument, we just recurse. +function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) + return condition_gibbs(condition_gibbs(context, value), values...) end -# extract varinfo object from state +# For `DynamicPPL.AbstractVarInfo` we just extract the values. """ - gibbs_varinfo(model, sampler, state) + condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) -Return the variables corresponding to the current `state` of the Gibbs component `sampler`. +Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. """ -gibbs_varinfo(model, sampler, state) = varinfo(state) -varinfo(state) = state.vi -varinfo(state::AbstractVarInfo) = state +function condition_gibbs( + context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo +) + return condition_gibbs( + context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)) + ) +end +function condition_gibbs( + context::DynamicPPL.AbstractContext, + varinfo::DynamicPPL.AbstractVarInfo, + varinfos::DynamicPPL.AbstractVarInfo..., +) + return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) +end +# Allow calling this on a `DynamicPPL.Model` directly. +function condition_gibbs(model::DynamicPPL.Model, values...) + return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) +end + +""" + make_conditional_model(model, varinfo, varinfos) + +Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. + +# Examples +```julia-repl +julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); + +julia> # A separate varinfo for each variable in `model`. + varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); + +julia> # The varinfo we want to NOT condition on. + target_varinfo = first(varinfos); + +julia> # Results in a model with only `m` conditioned. + conditioned_model = make_conditional(model, target_varinfo, varinfos); + +julia> result = conditioned_model(); + +julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` +true + +julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` +true +``` +""" +function make_conditional( + model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos +) + # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. + return condition_gibbs(model, filter(Base.Fix1(!==, target_varinfo), varinfos)...) +end +# Assumes the ones given are the ones to condition on. +function make_conditional(model::DynamicPPL.Model, varinfos) + return condition_gibbs(model, varinfos...) +end + +# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` +# or an `AbstractInferenceAlgorithm`. +wrap_algorithm_maybe(x) = x +wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) """ gibbs_state(model, sampler, state, varinfo) @@ -130,122 +249,263 @@ function gibbs_state( end """ - gibbs_rerun(prev_alg, alg) + Gibbs -Check if the model should be rerun to recompute the log density before sampling with the -Gibbs component `alg` and after sampling from Gibbs component `prev_alg`. +A type representing a Gibbs sampler. -By default, the function returns `true`. +# Fields +$(TYPEDFIELDS) """ -gibbs_rerun(prev_alg, alg) = true +struct Gibbs{V,A} <: InferenceAlgorithm + "varnames representing variables for each sampler" + varnames::V + "samplers for each entry in `varnames`" + samplers::A +end + +# NamedTuple +Gibbs(; algs...) = Gibbs(NamedTuple(algs)) +function Gibbs(algs::NamedTuple) + return Gibbs( + map(s -> VarName{s}(), keys(algs)), map(wrap_algorithm_maybe, values(algs)) + ) +end -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or one of the standard `Hamiltonian` algorithms -gibbs_rerun(::GibbsConditional, ::MH) = false -gibbs_rerun(::Hamiltonian, ::MH) = false +# AbstractDict +function Gibbs(algs::AbstractDict) + return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) +end +function Gibbs(algs::Pair...) + return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) +end -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or a `MH` algorithm -gibbs_rerun(::MH, ::Hamiltonian) = false -gibbs_rerun(::GibbsConditional, ::Hamiltonian) = false +# TODO: Remove when no longer needed. +DynamicPPL.getspace(::Gibbs) = () -# do not have to recompute `vi.logp` since it is not used in `step` -gibbs_rerun(prev_alg, ::GibbsConditional) = false +struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} + vi::V + states::S +end -# Do not recompute `vi.logp` since it is reset anyway in `step` -gibbs_rerun(prev_alg, ::PG) = false +_maybevec(x) = vec(x) # assume it's iterable +_maybevec(x::Tuple) = [x...] +_maybevec(x::VarName) = [x] -# Initialize the Gibbs sampler. function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, vi::AbstractVarInfo; kwargs... + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + vi_base::DynamicPPL.AbstractVarInfo; + initial_params=nothing, + kwargs..., ) - # TODO: Technically this only works for `VarInfo` or `ThreadSafeVarInfo{<:VarInfo}`. - # Should we enforce this? - - # Create tuple of samplers - algs = spl.alg.algs - i = 0 - samplers = map(algs) do alg - i += 1 - if i == 1 - prev_alg = algs[end] - else - prev_alg = algs[i - 1] - end - rerun = gibbs_rerun(prev_alg, alg) - selector = DynamicPPL.Selector(Symbol(typeof(alg)), rerun) - Sampler(alg, model, selector) - end + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers - # Add Gibbs to gids for all variables. - for sym in keys(vi.metadata) - vns = getfield(vi.metadata, sym).vns + # 1. Run the model once to get the varnames present + initial values to condition on. + vi_base = DynamicPPL.VarInfo(model) - for vn in vns - # update the gid for the Gibbs sampler - DynamicPPL.updategid!(vi, vn, spl) + # Simple way of setting the initial parameters: set them in the `vi_base` + # if they are given so they propagate to the subset varinfos used by each sampler. + if initial_params !== nothing + vi_base = DynamicPPL.unflatten(vi_base, initial_params) + end - # try to store each subsampler's gid in the VarInfo - for local_spl in samplers - DynamicPPL.updategid!(vi, vn, local_spl) - end - end + # Create the varinfos for each sampler. + varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) + initial_params_all = if initial_params === nothing + fill(nothing, length(varnames)) + else + # Extract from the `vi_base`, which should have the values set correctly from above. + map(vi -> vi[:], varinfos) end - # Compute initial states of the local samplers. - states = map(samplers) do local_spl - # Recompute `vi.logp` if needed. - if local_spl.selector.rerun - vi = last( - DynamicPPL.evaluate!!( - model, vi, DynamicPPL.SamplingContext(rng, local_spl) - ), - ) + # 2. Construct a varinfo for every vn + sampler combo. + states_and_varinfos = map( + samplers, varinfos, initial_params_all + ) do sampler_local, varinfo_local, initial_params_local + # Construct the conditional model. + model_local = make_conditional(model, varinfo_local, varinfos) + + # Take initial step. + new_state_local = last( + AbstractMCMC.step( + rng, + model_local, + sampler_local; + # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. + # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. + initial_params=initial_params_local, + kwargs..., + ), + ) + + # Return the new state and the invlinked `varinfo`. + vi_local_state = varinfo(new_state_local) + vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) + DynamicPPL.invlink(vi_local_state, sampler_local, model_local) + else + vi_local_state end + return (new_state_local, vi_local_state_linked) + end + + states = map(first, states_and_varinfos) + varinfos = map(last, states_and_varinfos) - # Compute initial state. - _, state = DynamicPPL.initialstep(rng, model, local_spl, vi; kwargs...) + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) + ) - # Update `VarInfo` object. - vi = gibbs_varinfo(model, local_spl, state) + return Transition(model, vi), GibbsState(vi, states) +end - return state +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + state::GibbsState; + kwargs..., +) + alg = spl.alg + samplers = alg.samplers + states = state.states + varinfos = map(varinfo, state.states) + @assert length(samplers) == length(state.states) + + # TODO: move this into a recursive function so we can unroll when reasonable? + for index in 1:length(samplers) + # Take the inner step. + new_state_local, new_varinfo_local = gibbs_step_inner( + rng, model, samplers, states, varinfos, index; kwargs... + ) + + # Update the `states` and `varinfos`. + states = Accessors.setindex(states, new_state_local, index) + varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) end - # Compute initial transition and state. - transition = Transition(model, vi) - state = GibbsState(vi, samplers, states) + # Combine the resulting varinfo objects. + # The last varinfo holds the correctly computed logp. + vi_base = state.vi + + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!( + varinfos, merge(vi_base, first(varinfos)), firstindex(varinfos) + ) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) + ) - return transition, state + return Transition(model, vi), GibbsState(vi, states) end -# Subsequent steps -function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, state::GibbsState; kwargs... -) - # Iterate through each of the samplers. - vi = state.vi - samplers = state.samplers - states = map(samplers, spl.alg.iterations, state.states) do _sampler, iteration, _state - # Recompute `vi.logp` if needed. - if _sampler.selector.rerun - vi = last(DynamicPPL.evaluate!!(model, rng, vi, _sampler)) - end +# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. +function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) + # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide + # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact + # same `selector` as before but now with `rerun` set to `true` if needed. + return Accessors.@set sampler.selector.rerun = true +end + +# Interface we need a sampler to implement to work as a component in a Gibbs sampler. +""" + gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - # Update state of current sampler with updated `VarInfo` object. - current_state = gibbs_state(model, _sampler, _state, vi) +Check if the log-probability of the destination model needs to be recomputed. - # Step through the local sampler. - newstate = current_state - for _ in 1:iteration - _, newstate = AbstractMCMC.step(rng, model, _sampler, newstate; kwargs...) - end +Defaults to `true` +""" +function gibbs_requires_recompute_logprob( + model_dst, sampler_dst, sampler_src, state_dst, state_src +) + return true +end - # Update `VarInfo` object. - vi = gibbs_varinfo(model, _sampler, newstate) +# TODO: Remove `rng`? +function recompute_logprob!!( + rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, state +) + vi = varinfo(state) + # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, + # e.g. log-likelihood in the scenario of `ESS`. + # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. + sampler_rerun = make_rerun_sampler(model, sampler) + # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed + # `varinfo`, even if `varinfo` was linked. + vi_new = last( + DynamicPPL.evaluate!!( + model, + vi, + # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. + DynamicPPL.SamplingContext(rng, sampler_rerun), + ) + ) + # Update the state we're about to use if need be. + # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. + return gibbs_state(model, sampler, state, vi_new) +end + +function gibbs_step_inner( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + samplers, + states, + varinfos, + index; + kwargs..., +) + # Needs to do a a few things. + sampler_local = samplers[index] + state_local = states[index] + varinfo_local = varinfos[index] + + # Make sure that all `varinfos` are linked. + varinfos_invlinked = map(varinfos) do vi + # NOTE: This is immutable linking! + # TODO: Do we need the `istrans` check here or should we just always use `invlink`? + # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 + DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi + end + varinfo_local_invlinked = varinfos_invlinked[index] + + # 1. Create conditional model. + # Construct the conditional model. + # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, + # otherwise we're conditioning on values which are not in the support of the + # distributions. + model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) + + # Extract the previous sampler and state. + sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] + state_previous = states[index == 1 ? length(states) : index - 1] + + # 1. Re-run the sampler if needed. + if gibbs_requires_recompute_logprob( + model_local, sampler_local, sampler_previous, state_local, state_previous + ) + state_local = recompute_logprob!!(rng, model_local, sampler_local, state_local) + end - return newstate + # 2. Take step with local sampler. + new_state_local = last( + AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) + ) + + # 3. Extract the new varinfo. + # Return the resulting state and invlinked `varinfo`. + varinfo_local_state = varinfo(new_state_local) + varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) + DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) + else + varinfo_local_state end - return Transition(model, vi), GibbsState(vi, samplers, states) + # TODO: alternatively, we can return `states_new, varinfos_new, index_new` + return (new_state_local, varinfo_local_state_invlinked) end diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl deleted file mode 100644 index fda79315b2..0000000000 --- a/src/mcmc/gibbs_conditional.jl +++ /dev/null @@ -1,88 +0,0 @@ -""" - GibbsConditional(sym, conditional) - -A "pseudo-sampler" to manually provide analytical Gibbs conditionals to `Gibbs`. -`GibbsConditional(:x, cond)` will sample the variable `x` according to the conditional `cond`, which -must therefore be a function from a `NamedTuple` of the conditioned variables to a `Distribution`. - - -The `NamedTuple` that is passed in contains all random variables from the model in an unspecified -order, taken from the [`VarInfo`](@ref) object over which the model is run. Scalars and vectors are -stored in their respective shapes. The tuple also contains the value of the conditioned variable -itself, which can be useful, but using it creates something that is not a Gibbs sampler anymore (see -[here](https://github.com/TuringLang/Turing.jl/pull/1275#discussion_r434240387)). - -# Examples - -```julia -α_0 = 2.0 -θ_0 = inv(3.0) -x = [1.5, 2.0] -N = length(x) - -@model function inverse_gdemo(x) - λ ~ Gamma(α_0, θ_0) - σ = sqrt(1 / λ) - m ~ Normal(0, σ) - @. x ~ \$(Normal(m, σ)) -end - -# The conditionals can be formulated in terms of the following statistics: -x_bar = mean(x) # sample mean -s2 = var(x; mean=x_bar, corrected=false) # sample variance -m_n = N * x_bar / (N + 1) - -function cond_m(c) - λ_n = c.λ * (N + 1) - σ_n = sqrt(1 / λ_n) - return Normal(m_n, σ_n) -end - -function cond_λ(c) - α_n = α_0 + (N - 1) / 2 + 1 - β_n = s2 * N / 2 + c.m^2 / 2 + inv(θ_0) - return Gamma(α_n, inv(β_n)) -end - -m = inverse_gdemo(x) - -sample(m, Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)), 10) -``` -""" -struct GibbsConditional{S,C} - conditional::C - - function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{sym,C}(conditional) - end -end - -DynamicPPL.getspace(::GibbsConditional{S}) where {S} = (S,) - -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - return nothing, vi -end - -function AbstractMCMC.step( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - condvals = DynamicPPL.values_as(DynamicPPL.invlink(vi, model), NamedTuple) - conddist = spl.alg.conditional(condvals) - updated = rand(rng, conddist) - # Setindex allows only vectors in this case. - vi = setindex!!(vi, [updated;], spl) - # Update log joint probability. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - - return nothing, vi -end diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl deleted file mode 100644 index 0f0740f14a..0000000000 --- a/test/experimental/gibbs.jl +++ /dev/null @@ -1,270 +0,0 @@ -module ExperimentalGibbsTests - -using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo -using ..NumericalTests: check_MoGtest_default, check_MoGtest_default_z_vector, check_gdemo, - check_numerical, two_sample_test -using DynamicPPL -using Random -using Test -using Turing -using Turing.Inference: AdvancedHMC, AdvancedMH -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff - -function check_transition_varnames( - transition::Turing.Inference.Transition, - parent_varnames -) - transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val - [first(vn_and_val)] - end - # Varnames in `transition` should be subsumed by those in `vns`. - for vn in transition_varnames - @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) - end -end - -const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ - Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_literal_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, -} -has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false -has_dot_assume(::Model) = true - -@testset "Gibbs using `condition`" begin - @testset "Demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - # Run one sampler on variables starting with `s` and another on variables starting with `m`. - vns_s = filter(vns) do vn - DynamicPPL.getsym(vn) == :s - end - vns_m = filter(vns) do vn - DynamicPPL.getsym(vn) == :m - end - - samplers = [ - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => NUTS(), - ), - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => HMC(0.01, 4), - ) - ] - - if !has_dot_assume(model) - # Add in some MH samplers, which are not compatible with `.~`. - append!( - samplers, - [ - Turing.Experimental.Gibbs( - vns_s => HMC(0.01, 4), - vns_m => MH(), - ), - Turing.Experimental.Gibbs( - vns_s => MH(), - vns_m => HMC(0.01, 4), - ) - ] - ) - end - - @testset "$sampler" for sampler in samplers - # Check that taking steps performs as expected. - rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - check_transition_varnames(transition, vns) - end - end - - @testset "comparison with 'gold-standard' samples" begin - num_iterations = 1_000 - thinning = 10 - num_chains = 4 - - # Determine initial parameters to make comparison as fair as possible. - posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) - initial_params = DynamicPPL.TestUtils.update_values!!( - DynamicPPL.VarInfo(model), - posterior_mean, - DynamicPPL.TestUtils.varnames(model), - )[:] - initial_params = fill(initial_params, num_chains) - - # Sampler to use for Gibbs components. - sampler_inner = HMC(0.1, 32) - sampler = Turing.Experimental.Gibbs( - vns_s => sampler_inner, - vns_m => sampler_inner, - ) - Random.seed!(42) - chain = sample( - model, - sampler, - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - discard_initial=1_000, - thinning=thinning - ) - - # "Ground truth" samples. - # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. - Random.seed!(42) - chain_true = sample( - model, - NUTS(), - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - thinning=thinning, - ) - - # Perform KS test to ensure that the chains are similar. - xs = Array(chain) - xs_true = Array(chain_true) - for i = 1:size(xs, 2) - @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) - # Let's make sure that the significance level is not too low by - # checking that the KS test fails for some simple transformations. - # TODO: Replace the heuristic below with closed-form implementations - # of the targets, once they are implemented in DynamicPPL. - @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) - end - end - end - end - - @testset "multiple varnames" begin - rng = Random.default_rng() - - @testset "with both `s` and `m` as random" begin - model = gdemo(1.5, 2.0) - vns = (@varname(s), @varname(m)) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - - # `sample` - Random.seed!(42) - chain = sample(model, alg, 10_000; progress=false) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) - end - - @testset "without `m` as random" begin - model = gdemo(1.5, 2.0) | (m=7 / 6,) - vns = (@varname(s),) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - end - end - - @testset "CSMC + ESS" begin - rng = Random.default_rng() - model = MoGtest_default - alg = Turing.Experimental.Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - Random.seed!(42) - chain = sample(MoGtest_default, alg, 1000; progress=false) - check_MoGtest_default(chain, atol = 0.2) - end - - @testset "CSMC + ESS (usage of implicit varname)" begin - rng = Random.default_rng() - model = MoGtest_default_z_vector - alg = Turing.Experimental.Gibbs( - @varname(z) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z[1]), @varname(z[2]), @varname(z[3]), @varname(z[4]), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - Random.seed!(42) - chain = sample(model, alg, 1000; progress=false) - check_MoGtest_default_z_vector(chain, atol = 0.2) - end - - @testset "externsalsampler" begin - @model function demo_gibbs_external() - m1 ~ Normal() - m2 ~ Normal() - - -1 ~ Normal(m1, 1) - +1 ~ Normal(m1 + m2, 1) - - return (; m1, m2) - end - - model = demo_gibbs_external() - samplers_inner = [ - externalsampler(AdvancedMH.RWMH(1)), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)), - ] - @testset "$(sampler_inner)" for sampler_inner in samplers_inner - sampler = Turing.Experimental.Gibbs( - @varname(m1) => sampler_inner, - @varname(m2) => sampler_inner, - ) - Random.seed!(42) - chain = sample(model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0) - check_numerical(chain, [:m1, :m2], [-0.2, 0.6], atol=0.1) - end - end -end - -end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 15ec6149c0..4a6e0e9a6d 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -33,15 +33,15 @@ ADUtils.install_tapir && import Tapir PG(10), IS(), MH(), - Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + Gibbs(; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)), + Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), ) else ( HMC(0.1, 7; adtype=adbackend), IS(), MH(), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), ) end for sampler in samplers @@ -85,7 +85,7 @@ ADUtils.install_tapir && import Tapir alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) alg2 = PG(20) - alg3 = Gibbs(PG(30, :s), HMC(0.2, 4, :m; adtype=adbackend)) + alg3 = Gibbs(; s=PG(30), m=HMC(0.2, 4; adtype=adbackend)) chn1 = sample(gdemo_default, alg1, 5000; save_state=true) check_gdemo(chn1) @@ -234,7 +234,7 @@ ADUtils.install_tapir && import Tapir smc = SMC() pg = PG(10) - gibbs = Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x)) + gibbs = Gibbs(; p=HMC(0.2, 3; adtype=adbackend), x=PG(10)) chn_s = sample(testbb(obs), smc, 1000) chn_p = sample(testbb(obs), pg, 2000) @@ -261,7 +261,7 @@ ADUtils.install_tapir && import Tapir return s, m end - gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m; adtype=adbackend)) + gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8; adtype=adbackend)) chain = sample(fggibbstest(xs), gibbs, 2) end @testset "new grammar" begin @@ -367,7 +367,7 @@ ADUtils.install_tapir && import Tapir @test all(isone, res_pg[:x]) end @testset "sample" begin - alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) + alg = Gibbs(; m=HMC(0.2, 3; adtype=adbackend), s=PG(10)) chn = sample(gdemo_default, alg, 1000) end @testset "vectorization @." begin diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 0a1c23a9eb..da03e686d9 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -38,7 +38,7 @@ using Turing c3 = sample(demodot_default, s1, N) c4 = sample(demodot_default, s2, N) - s3 = Gibbs(ESS(:m), MH(:s)) + s3 = Gibbs(; m=ESS(), s=MH()) c5 = sample(gdemo_default, s3, N) end @@ -52,13 +52,17 @@ using Turing check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1) Random.seed!(100) - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(; s=CSMC(15), m=ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) # MoGtest Random.seed!(125) - alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) + alg = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(m2) => ESS(), + ) chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain; atol=0.1) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 6868cb5e84..354a195377 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -1,30 +1,65 @@ module GibbsTests -using ..Models: MoGtest_default, gdemo, gdemo_default -using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical +using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo, gdemo_default +using ..NumericalTests: + check_MoGtest_default, + check_MoGtest_default_z_vector, + check_gdemo, + check_numerical, + two_sample_test import ..ADUtils using Distributions: InverseGamma, Normal using Distributions: sample +using DynamicPPL: DynamicPPL using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff using Test: @test, @testset using Turing using Turing: Inference +using Turing.Inference: AdvancedHMC, AdvancedMH using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess ADUtils.install_tapir && import Tapir +function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) + transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val + [first(vn_and_val)] + end + # Varnames in `transition` should be subsumed by those in `parent_varnames`. + for vn in transition_varnames + @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) + end +end + +const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_literal_dot_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, +} +has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false +has_dot_assume(::DynamicPPL.Model) = true + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "gibbs constructor" begin N = 500 - s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) - s2 = Gibbs(PG(10, :s, :m)) - s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - for s in (s1, s2, s3, s4, s5, s6) + s1 = begin + alg = HMC(0.1, 5, :s, :m; adtype=adbackend) + Gibbs(; s=alg, m=alg) + end + s2 = begin + alg = PG(10) + Gibbs(@varname(s) => alg, @varname(m) => alg) + end + s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) + s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) + s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) + s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) + s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) + for s in (s1, s2, s3, s4, s5, s6, s7) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end @@ -34,22 +69,15 @@ ADUtils.install_tapir && import Tapir c4 = sample(gdemo_default, s4, N) c5 = sample(gdemo_default, s5, N) c6 = sample(gdemo_default, s6, N) + c7 = sample(gdemo_default, s7, N) - # Test gid of each samplers g = Turing.Sampler(s3, gdemo_default) - - _, state = AbstractMCMC.step(Random.default_rng(), gdemo_default, g) - @test state.samplers[1].selector != g.selector - @test state.samplers[2].selector != g.selector - @test state.samplers[1].selector != state.samplers[2].selector - - # run sampler: progress logging should be disabled and - # it should return a Chains object @test sample(gdemo_default, g, N) isa MCMCChains.Chains end + @testset "gibbs inference" begin Random.seed!(100) - alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:m], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance. @@ -57,11 +85,11 @@ ADUtils.install_tapir && import Tapir Random.seed!(100) - alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(; s=CSMC(15), m=ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) @@ -71,15 +99,17 @@ ADUtils.install_tapir && import Tapir Random.seed!(200) gibbs = Gibbs( - PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2; adtype=adbackend) + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), ) chain = sample(MoGtest_default, gibbs, 10_000) check_MoGtest_default(chain; atol=0.15) Random.seed!(200) for alg in [ - Gibbs((MH(:s), 2), (HMC(0.2, 4, :m; adtype=adbackend), 1)), - Gibbs((MH(:s), 1), (HMC(0.2, 4, :m; adtype=adbackend), 2)), + # The new syntax for specifying a sampler to run twice for one variable. + Gibbs(s => MH(), s => MH(), m => HMC(0.2, 4; adtype=adbackend)), + Gibbs(s => MH(), m => HMC(0.2, 4), m => HMC(0.2, 4); adtype=adbackend), ] chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain; atol=0.15) @@ -113,9 +143,10 @@ ADUtils.install_tapir && import Tapir return nothing end - alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) sample(model, alg, 100; callback=callback) end + @testset "dynamic model" begin @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} N = length(y) @@ -136,10 +167,250 @@ ADUtils.install_tapir && import Tapir m[k] ~ Normal(1.0, 1.0) end end - model = imm(randn(100), 1.0) + model = imm(Random.randn(100), 1.0) # https://github.com/TuringLang/Turing.jl/issues/1725 # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100); - sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m; adtype=adbackend)), 100) + sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) + end + + @testset "Demo models" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + # Run one sampler on variables starting with `s` and another on variables starting with `m`. + vns_s = filter(vns) do vn + DynamicPPL.getsym(vn) == :s + end + vns_m = filter(vns) do vn + DynamicPPL.getsym(vn) == :m + end + + samplers = [ + Turing.Gibbs(vns_s => NUTS(), vns_m => NUTS()), + Turing.Gibbs(vns_s => NUTS(), vns_m => HMC(0.01, 4)), + ] + + if !has_dot_assume(model) + # Add in some MH samplers, which are not compatible with `.~`. + append!( + samplers, + [ + Turing.Gibbs(vns_s => HMC(0.01, 4), vns_m => MH()), + Turing.Gibbs(vns_s => MH(), vns_m => HMC(0.01, 4)), + ], + ) + end + + @testset "$sampler" for sampler in samplers + # Check that taking steps performs as expected. + rng = Random.default_rng() + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(sampler) + ) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(sampler), state + ) + check_transition_varnames(transition, vns) + end + end + + # Run the Gibbs sampler and NUTS on the same model, compare statistics of the + # chains. + @testset "comparison with 'gold-standard' samples" begin + num_iterations = 1_000 + thinning = 10 + num_chains = 4 + + # Determine initial parameters to make comparison as fair as possible. + posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) + initial_params = DynamicPPL.TestUtils.update_values!!( + DynamicPPL.VarInfo(model), + posterior_mean, + DynamicPPL.TestUtils.varnames(model), + )[:] + initial_params = fill(initial_params, num_chains) + + # Sampler to use for Gibbs components. + sampler_inner = HMC(0.1, 32) + sampler = Turing.Gibbs(vns_s => sampler_inner, vns_m => sampler_inner) + Random.seed!(42) + chain = sample( + model, + sampler, + MCMCThreads(), + num_iterations, + num_chains; + progress=false, + initial_params=initial_params, + discard_initial=1_000, + thinning=thinning, + ) + + # "Ground truth" samples. + # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. + Random.seed!(42) + chain_true = sample( + model, + NUTS(), + MCMCThreads(), + num_iterations, + num_chains; + progress=false, + initial_params=initial_params, + thinning=thinning, + ) + + # Perform KS test to ensure that the chains are similar. + xs = Array(chain) + xs_true = Array(chain_true) + for i in 1:size(xs, 2) + @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) + # Let's make sure that the significance level is not too low by + # checking that the KS test fails for some simple transformations. + # TODO: Replace the heuristic below with closed-form implementations + # of the targets, once they are implemented in DynamicPPL. + @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) + end + end + end + end + + @testset "multiple varnames" begin + rng = Random.default_rng() + + @testset "with both `s` and `m` as random" begin + model = gdemo(1.5, 2.0) + vns = (@varname(s), @varname(m)) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # `sample` + Random.seed!(42) + chain = sample(model, alg, 10_000; progress=false) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) + end + + @testset "without `m` as random" begin + model = gdemo(1.5, 2.0) | (m=7 / 6,) + vns = (@varname(s),) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + end + end + + @testset "CSMC + ESS" begin + rng = Random.default_rng() + model = MoGtest_default + alg = Turing.Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) + vns = ( + @varname(z1), + @varname(z2), + @varname(z3), + @varname(z4), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! + Random.seed!(42) + chain = sample(MoGtest_default, alg, 1000; progress=false) + check_MoGtest_default(chain; atol=0.2) + end + + @testset "CSMC + ESS (usage of implicit varname)" begin + rng = Random.default_rng() + model = MoGtest_default_z_vector + alg = Turing.Gibbs( + @varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS() + ) + vns = ( + @varname(z[1]), + @varname(z[2]), + @varname(z[3]), + @varname(z[4]), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! + Random.seed!(42) + chain = sample(model, alg, 1000; progress=false) + check_MoGtest_default_z_vector(chain; atol=0.2) + end + + @testset "externsalsampler" begin + @model function demo_gibbs_external() + m1 ~ Normal() + m2 ~ Normal() + + -1 ~ Normal(m1, 1) + +1 ~ Normal(m1 + m2, 1) + + return (; m1, m2) + end + + model = demo_gibbs_external() + samplers_inner = [ + externalsampler(AdvancedMH.RWMH(1)), + externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoForwardDiff()), + externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff()), + externalsampler( + AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff(; compile=true) + ), + ] + @testset "$(sampler_inner)" for sampler_inner in samplers_inner + sampler = Turing.Gibbs( + @varname(m1) => sampler_inner, @varname(m2) => sampler_inner + ) + Random.seed!(42) + chain = sample( + model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0 + ) + check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1) + end end end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl deleted file mode 100644 index 3f02c75945..0000000000 --- a/test/mcmc/gibbs_conditional.jl +++ /dev/null @@ -1,172 +0,0 @@ -module GibbsConditionalTests - -using ..Models: gdemo, gdemo_default -using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils -using Clustering: Clustering -using Distributions: Categorical, InverseGamma, Normal, sample -using ForwardDiff: ForwardDiff -using LinearAlgebra: Diagonal, I -using Random: Random -using ReverseDiff: ReverseDiff -using StableRNGs: StableRNG -using StatsBase: counts -using StatsFuns: StatsFuns -using Test: @test, @testset -using Turing - -ADUtils.install_tapir && import Tapir - -@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ADUtils.adbackends - Random.seed!(1000) - rng = StableRNG(123) - - @testset "gdemo" begin - # We consider the model - # ```math - # s ~ InverseGamma(2, 3) - # m ~ Normal(0, √s) - # xᵢ ~ Normal(m, √s), i = 1, …, N, - # ``` - # with ``N = 2`` observations ``x₁ = 1.5`` and ``x₂ = 2``. - - # The conditionals and posterior can be formulated in terms of the following statistics: - N = 2 - x_mean = 1.75 # sample mean ``∑ xᵢ / N`` - x_var = 0.0625 # sample variance ``∑ (xᵢ - x_bar)^2 / N`` - m_n = 3.5 / 3 # ``∑ xᵢ / (N + 1)`` - - # Conditional distribution - # ```math - # m | s, x ~ Normal(m_n, sqrt(s / (N + 1))) - # ``` - cond_m = let N = N, m_n = m_n - c -> Normal(m_n, sqrt(c.s / (N + 1))) - end - - # Conditional distribution - # ```math - # s | m, x ~ InverseGamma(2 + (N + 1) / 2, 3 + (m^2 + ∑ (xᵢ - m)^2) / 2) = - # InverseGamma(2 + (N + 1) / 2, 3 + m^2 / 2 + N / 2 * (x_var + (x_mean - m)^2)) - # ``` - cond_s = let N = N, x_mean = x_mean, x_var = x_var - c -> InverseGamma( - 2 + (N + 1) / 2, 3 + c.m^2 / 2 + N / 2 * (x_var + (x_mean - c.m)^2) - ) - end - - # Three Gibbs samplers: - # one for each variable fixed to the posterior mean - s_posterior_mean = 49 / 24 - sampler1 = Gibbs( - GibbsConditional(:m, cond_m), - GibbsConditional(:s, _ -> Normal(s_posterior_mean, 0)), - ) - chain = sample(rng, gdemo_default, sampler1, 10_000) - cond_m_mean = mean(cond_m((s=s_posterior_mean,))) - check_numerical(chain, [:m, :s], [cond_m_mean, s_posterior_mean]) - @test all(==(s_posterior_mean), chain[:s][2:end]) - - m_posterior_mean = 7 / 6 - sampler2 = Gibbs( - GibbsConditional(:m, _ -> Normal(m_posterior_mean, 0)), - GibbsConditional(:s, cond_s), - ) - chain = sample(rng, gdemo_default, sampler2, 10_000) - cond_s_mean = mean(cond_s((m=m_posterior_mean,))) - check_numerical(chain, [:m, :s], [m_posterior_mean, cond_s_mean]) - @test all(==(m_posterior_mean), chain[:m][2:end]) - - # and one for both using the conditional - sampler3 = Gibbs(GibbsConditional(:m, cond_m), GibbsConditional(:s, cond_s)) - chain = sample(rng, gdemo_default, sampler3, 10_000) - check_gdemo(chain) - end - - @testset "GMM" begin - Random.seed!(1000) - rng = StableRNG(123) - # We consider the model - # ```math - # μₖ ~ Normal(m, σ_μ), k = 1, …, K, - # zᵢ ~ Categorical(π), i = 1, …, N, - # xᵢ ~ Normal(μ_{zᵢ}, σₓ), i = 1, …, N, - # ``` - # with ``K = 2`` clusters, ``N = 20`` observations, and the following parameters: - K = 2 # number of clusters - π = fill(1 / K, K) # uniform cluster weights - m = 0.5 # prior mean of μₖ - σ²_μ = 4.0 # prior variance of μₖ - σ²_x = 0.01 # observation variance - N = 20 # number of observations - - # We generate data - μ_data = rand(rng, Normal(m, sqrt(σ²_μ)), K) - z_data = rand(rng, Categorical(π), N) - x_data = rand(rng, MvNormal(μ_data[z_data], σ²_x * I)) - - @model function mixture(x) - μ ~ $(MvNormal(fill(m, K), σ²_μ * I)) - z ~ $(filldist(Categorical(π), N)) - x ~ MvNormal(μ[z], $(σ²_x * I)) - return x - end - model = mixture(x_data) - - # Conditional distribution ``z | μ, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_z = let x = x_data, log_π = log.(π), σ_x = sqrt(σ²_x) - c -> begin - dists = map(x) do xi - logp = log_π .+ logpdf.(Normal.(c.μ, σ_x), xi) - return Categorical(StatsFuns.softmax!(logp)) - end - return arraydist(dists) - end - end - - # Conditional distribution ``μ | z, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_μ = let K = K, x_data = x_data, inv_σ²_μ = inv(σ²_μ), inv_σ²_x = inv(σ²_x) - c -> begin - # Convert cluster assignments to one-hot encodings - z_onehot = c.z .== (1:K)' - - # Count number of observations in each cluster - n = vec(sum(z_onehot; dims=1)) - - # Compute mean and variance of the conditional distribution - μ_var = @. inv(inv_σ²_x * n + inv_σ²_μ) - μ_mean = (z_onehot' * x_data) .* inv_σ²_x .* μ_var - - return MvNormal(μ_mean, Diagonal(μ_var)) - end - end - - estimate(chain, var) = dropdims(mean(Array(group(chain, var)); dims=1); dims=1) - function estimatez(chain, var, range) - z = Int.(Array(group(chain, var))) - return map(i -> findmax(counts(z[:, i], range))[2], 1:size(z, 2)) - end - - lμ_data, uμ_data = extrema(μ_data) - - # Compare three Gibbs samplers - sampler1 = Gibbs(GibbsConditional(:z, cond_z), GibbsConditional(:μ, cond_μ)) - sampler2 = Gibbs(GibbsConditional(:z, cond_z), MH(:μ)) - sampler3 = Gibbs(GibbsConditional(:z, cond_z), HMC(0.01, 7, :μ; adtype=adbackend)) - for sampler in (sampler1, sampler2, sampler3) - chain = sample(rng, model, sampler, 10_000) - - μ_hat = estimate(chain, :μ) - lμ_hat, uμ_hat = extrema(μ_hat) - @test isapprox([lμ_data, uμ_data], [lμ_hat, uμ_hat], atol=0.1) - - z_hat = estimatez(chain, :z, 1:2) - ari, _, _, _ = Clustering.randindex(z_data, Int.(z_hat)) - @test isapprox(ari, 1, atol=0.1) - end - end -end - -end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index dde977a6f0..889be13c5b 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -130,9 +130,9 @@ ADUtils.install_tapir && import Tapir @testset "hmcda inference" begin alg1 = HMCDA(500, 0.8, 0.015; adtype=adbackend) - # alg2 = Gibbs(HMCDA(200, 0.8, 0.35, :m; adtype=adbackend), HMC(0.25, 3, :s; adtype=adbackend)) + # alg2 = Gibbs(; m=HMCDA(200, 0.8, 0.35; adtype=adbackend), s=HMC(0.25, 3; adtype=adbackend)) - # alg3 = Gibbs(HMC(0.25, 3, :m; adtype=adbackend), PG(30, 3, :s)) + # alg3 = Gibbs(; m=HMC(0.25, 3; adtype=adbackend), s=PG(30, 3)) # alg3 = PG(50, 2000) res1 = sample(rng, gdemo_default, alg1, 3000) @@ -147,7 +147,7 @@ ADUtils.install_tapir && import Tapir @testset "hmcda+gibbs inference" begin rng = StableRNG(123) Random.seed!(12345) # particle samplers do not support user-provided `rng` yet - alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) + alg3 = Gibbs(; s=PG(20), m=HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend)) res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000) check_gdemo(res3) @@ -200,9 +200,9 @@ ADUtils.install_tapir && import Tapir @test size(c2, 1) == 500 end @testset "AHMC resize" begin - alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s; adtype=adbackend)) - alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s; adtype=adbackend)) - alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s; adtype=adbackend)) + alg1 = Gibbs(; m=PG(10), s=NUTS(100, 0.65; adtype=adbackend)) + alg2 = Gibbs(; m=PG(10), s=HMC(0.1, 3; adtype=adbackend)) + alg3 = Gibbs(; m=PG(10), s=HMCDA(100, 0.65, 0.3; adtype=adbackend)) @test sample(rng, gdemo_default, alg1, 300) isa Chains @test sample(rng, gdemo_default, alg2, 300) isa Chains @test sample(rng, gdemo_default, alg3, 300) isa Chains diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index a01d3dc253..f454db5a05 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -32,7 +32,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) c2 = sample(gdemo_default, s2, N) c3 = sample(gdemo_default, s3, N) - s4 = Gibbs(MH(:m), MH(:s)) + s4 = Gibbs(; m=MH(), s=MH()) c4 = sample(gdemo_default, s4, N) # s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) @@ -62,14 +62,16 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) Random.seed!(125) # MH within Gibbs - alg = Gibbs(MH(:m), MH(:s)) + alg = Gibbs(; m=MH(), s=MH()) chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params) check_gdemo(chain; atol=0.1) Random.seed!(125) # MoGtest gibbs = Gibbs( - CSMC(15, :z1, :z2, :z3, :z4), MH((:mu1, GKernel(1)), (:mu2, GKernel(1))) + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => MH((:mu1, GKernel(1))), + @varname(mu2) => MH((:mu2, GKernel(1))), ) chain = sample( MoGtest_default, @@ -167,7 +169,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) vc_μ = convert(Array, 1e-4 * I(2)) vc_σ = convert(Array, 1e-4 * I(2)) - alg = Gibbs(MH((:μ, vc_μ)), MH((:σ, vc_σ))) + alg = Gibbs(; μ=MH((:μ, vc_μ)), σ=MH((:σ, vc_σ))) chn = sample( mod, diff --git a/test/runtests.jl b/test/runtests.jl index 1aa8bb635b..ba9aafd2ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,7 +46,6 @@ end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" begin @timeit_include("mcmc/gibbs.jl") - @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") @@ -65,10 +64,6 @@ end end end - @testset "experimental" begin - @timeit_include("experimental/gibbs.jl") - end - @testset "variational optimisers" begin @timeit_include("variational/optimisers.jl") end diff --git a/test/skipped/explicit_ret.jl b/test/skipped/explicit_ret.jl index c1340464fb..2dabc09bd3 100644 --- a/test/skipped/explicit_ret.jl +++ b/test/skipped/explicit_ret.jl @@ -12,7 +12,7 @@ end mf = test_ex_rt() for alg in - [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(PG(20, 1, :x), HMC(0.2, 3, :y))] + [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(; x=PG(20, 1), y=HMC(0.2, 3))] chn = sample(mf, alg) @test mean(chn[:x]) ≈ 10.0 atol = 0.2 @test mean(chn[:y]) ≈ 5.0 atol = 0.2 From 5a3e4a66cdacd931c023631576829d8401bf5207 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 23 Sep 2024 15:59:40 +0100 Subject: [PATCH 02/70] Remove dead references to experimental --- .github/workflows/Tests.yml | 5 ++--- src/Turing.jl | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 8de296e5ee..770eab9a70 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -22,9 +22,8 @@ jobs: - "mcmc/hmc.jl" - "mcmc/abstractmcmc.jl" - "mcmc/Inference.jl" - - "experimental/gibbs.jl" - "mcmc/ess.jl" - - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl" + - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl" version: - '1.7' - '1' @@ -79,7 +78,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 # TODO: Use julia-actions/julia-runtest when test_args are supported # Custom calls of Pkg.test tend to miss features such as e.g. adjustments for CompatHelper PRs - # Ref https://github.com/julia-actions/julia-runtest/pull/73 + # Ref https://github.com/julia-actions/julia-runtest/pull/73 - name: Call Pkg.test run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test-args }} - uses: julia-actions/julia-processcoverage@v1 diff --git a/src/Turing.jl b/src/Turing.jl index 8fcee6c185..027c190a3c 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -55,7 +55,6 @@ using .Variational include("optimisation/Optimisation.jl") using .Optimisation -include("experimental/Experimental.jl") include("deprecated.jl") # to be removed in the next minor version release ########### From 09c739d0c040545017977a749b5f2305ffd5d462 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 23 Sep 2024 16:00:32 +0100 Subject: [PATCH 03/70] Remove mention of experimental from JuliaFormatter conf --- .JuliaFormatter.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index d0e00b45f8..745726d468 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -6,9 +6,7 @@ import_to_using = false # These ignores should be removed once the relevant PRs are merged/closed. ignore = [ # https://github.com/TuringLang/Turing.jl/pull/2231/files - "src/experimental/gibbs.jl", "src/mcmc/abstractmcmc.jl", - "test/experimental/gibbs.jl", "test/test_utils/numerical_tests.jl", # https://github.com/TuringLang/Turing.jl/pull/2218/files "src/mcmc/Inference.jl", From 58ebb259af5adc843790c57d7654e3a656a62687 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 10:05:20 +0100 Subject: [PATCH 04/70] Add tests for deprecated constructor --- src/mcmc/gibbs.jl | 20 ++++++++++++++++++++ test/mcmc/gibbs.jl | 38 ++++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index fb05b64757..754451a508 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -279,6 +279,26 @@ function Gibbs(algs::Pair...) return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) end +# The below constructor only serves to provide backwards compatibility with the constructor +# of the old Gibbs sampler. It is deprecated and will be removed in the future. +function Gibbs(algs::InferenceAlgorithm...) + alg_dict = Dict{Any,InferenceAlgorithm}() + for alg in algs + space = getspace(alg) + space_vns = if (space isa Symbol || space isa VarName) + space + else + tuple((s isa Symbol ? VarName{s}() : s for s in space)...) + end + alg_dict[space_vns] = alg + end + Base.depwarn( + "Specifying which sampler to use with which variable using syntax like `Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. Please use `Gibbs(; x=NUTS(), y=MH())` instead.", + :Gibbs, + ) + return Gibbs(alg_dict) +end + # TODO: Remove when no longer needed. DynamicPPL.getspace(::Gibbs) = () diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 354a195377..dbf4271c18 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -14,7 +14,7 @@ using DynamicPPL: DynamicPPL using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff -using Test: @test, @testset +using Test: @test, @test_deprecated, @testset using Turing using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH @@ -44,16 +44,34 @@ has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false has_dot_assume(::DynamicPPL.Model) = true @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends - @testset "gibbs constructor" begin - N = 500 - s1 = begin - alg = HMC(0.1, 5, :s, :m; adtype=adbackend) - Gibbs(; s=alg, m=alg) - end - s2 = begin - alg = PG(10) - Gibbs(@varname(s) => alg, @varname(m) => alg) + @testset "Deprecated Gibbs constructors" begin + N = 10 + @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) + @test_deprecated s2 = Gibbs(PG(10, :s, :m)) + @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) + for s in (s1, s2, s3, s4, s5, s6) + @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end + + # Check that the samplers work despite using the deprecated constructor. + sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + + g = Turing.Sampler(s3, gdemo_default) + @test sample(gdemo_default, g, N) isa MCMCChains.Chains + end + + @testset "Gibbs constructors" begin + N = 10 + s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5, :s, :m; adtype=adbackend)) + s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) From 771573293eb42dbf45226ac4689bd04ae8352855 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 14:43:21 +0100 Subject: [PATCH 05/70] Fix deprecated Gibbs constructors. Add HISTORY entry. --- HISTORY.md | 16 ++++++++++++++++ src/mcmc/gibbs.jl | 35 +++++++++++++++++++++++++---------- test/mcmc/gibbs.jl | 4 +++- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 5b1cad0ede..11d08e12ca 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +# Release 0.35.0 + +## Breaking changes + +0.35.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely. + +The new Gibbs sampler supports the same user-facing interface as the old one. However, given +that the internals of it having been completely rewritten in a very different manner, there +may be accidental breakage that we haven't anticipated. Please report any you find. + +`GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking. + +The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable. + +Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(:x), 2), (MH(:y), 1))` has been deprecated. The new way to achieve this effect is to list the same sampler multiple times, e.g. as `hmc = HMC(); mh = MH(); Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)`. + # Release 0.33.0 ## Breaking changes diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 754451a508..445bc433f8 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -279,24 +279,39 @@ function Gibbs(algs::Pair...) return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) end -# The below constructor only serves to provide backwards compatibility with the constructor -# of the old Gibbs sampler. It is deprecated and will be removed in the future. +# The below two constructors only provide backwards compatibility with the constructor of +# the old Gibbs sampler. They are deprecated and will be removed in the future. function Gibbs(algs::InferenceAlgorithm...) - alg_dict = Dict{Any,InferenceAlgorithm}() - for alg in algs + varnames = map(algs) do alg space = getspace(alg) - space_vns = if (space isa Symbol || space isa VarName) + if (space isa VarName) space + elseif (space isa Symbol) + VarName{space}() else tuple((s isa Symbol ? VarName{s}() : s for s in space)...) end - alg_dict[space_vns] = alg end - Base.depwarn( - "Specifying which sampler to use with which variable using syntax like `Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. Please use `Gibbs(; x=NUTS(), y=MH())` instead.", - :Gibbs, + msg = ( + "Specifying which sampler to use with which variable using syntax like " * + "`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " * + "Please use `Gibbs(; x=NUTS(), y=MH())` instead. If you want different iteration " * + "counts for different subsamplers, use e.g. " * + "`Gibbs(@varname(x) => NUTS(), @varname(x) => NUTS(), @varname(y) => MH())`" ) - return Gibbs(alg_dict) + Base.depwarn(msg, :Gibbs) + return Gibbs(varnames, map(wrap_algorithm_maybe, algs)) +end + +function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) + algs = Iterators.map(first, algs_with_iters) + iters = Iterators.map(last, algs_with_iters) + algs_duplicated = Iterators.flatten(( + Iterators.repeated(alg, iter) for (alg, iter) in zip(algs, iters) + )) + # This calls the other deprecated constructor from above, hence no need for a depwarn + # here. + return Gibbs(algs_duplicated...) end # TODO: Remove when no longer needed. diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index dbf4271c18..9162c6cf8b 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -52,7 +52,8 @@ has_dot_assume(::DynamicPPL.Model) = true @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - for s in (s1, s2, s3, s4, s5, s6) + @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) + for s in (s1, s2, s3, s4, s5, s6, s7) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end @@ -63,6 +64,7 @@ has_dot_assume(::DynamicPPL.Model) = true sample(gdemo_default, s4, N) sample(gdemo_default, s5, N) sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains From 672f7d986bcacc3fffc438de14b7581ce839ccc2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 14:43:38 +0100 Subject: [PATCH 06/70] Bump version to 0.35.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5f5c86b04e..e22c4f4ae7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.34.1" +version = "0.35.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 7bf5abefde676b5aad9f293319ee7eac49f047ba Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 15:39:50 +0100 Subject: [PATCH 07/70] Add Gibbs constructor test for repeat samplers --- test/mcmc/gibbs.jl | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 9162c6cf8b..5082a5f4fa 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -79,17 +79,25 @@ has_dot_assume(::DynamicPPL.Model) = true s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) - for s in (s1, s2, s3, s4, s5, s6, s7) + s8 = begin + hmc = HMC(0.1, 5; adtype=adbackend) + pg = PG(10) + vns = @varname(s) + vnm = @varname(m) + Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) + end + for s in (s1, s2, s3, s4, s5, s6, s7, s8) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end - c1 = sample(gdemo_default, s1, N) - c2 = sample(gdemo_default, s2, N) - c3 = sample(gdemo_default, s3, N) - c4 = sample(gdemo_default, s4, N) - c5 = sample(gdemo_default, s5, N) - c6 = sample(gdemo_default, s6, N) - c7 = sample(gdemo_default, s7, N) + sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) + sample(gdemo_default, s8, N) g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains From 85bcfa51cf048c5f73054a136fde37e21d7e99f2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 15:57:31 +0100 Subject: [PATCH 08/70] Fix typo in test/mcmc/ess.jl --- test/mcmc/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index da03e686d9..8d9697d9ad 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -61,7 +61,7 @@ using Turing alg = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), - @varname(m2) => ESS(), + @varname(mu2) => ESS(), ) chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain; atol=0.1) From 6f9679ac659210276fd1cb9acdc2a728573699e3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 24 Sep 2024 16:11:34 +0100 Subject: [PATCH 09/70] Use provided rng to initialise VarInfo in Gibbs --- src/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 445bc433f8..571d694e33 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -339,7 +339,7 @@ function DynamicPPL.initialstep( samplers = alg.samplers # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(model) + vi_base = DynamicPPL.VarInfo(rng, model) # Simple way of setting the initial parameters: set them in the `vi_base` # if they are given so they propagate to the subset varinfos used by each sampler. From f247ad997a26151f42b180c1e2121df76d9ca6d5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 8 Oct 2024 15:01:02 +0100 Subject: [PATCH 10/70] Fix a typo in GibbsContext --- src/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 571d694e33..4986ff87d8 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -15,7 +15,7 @@ struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.Abstra context::Ctx end -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) +GibbsContext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context From d19afe18a971b6c623acea3aefff48ac59498911 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 11 Oct 2024 16:55:08 +0100 Subject: [PATCH 11/70] Fix the Gibbs sampler --- Project.toml | 2 +- src/mcmc/gibbs.jl | 182 +++++++++++++++++----------------------------- 2 files changed, 66 insertions(+), 118 deletions(-) diff --git a/Project.toml b/Project.toml index e22c4f4ae7..e23cb5f0be 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ TuringOptimExt = "Optim" [compat] ADTypes = "0.2, 1" -AbstractMCMC = "5.2" +AbstractMCMC = "5.5" Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" AdvancedMH = "0.8" diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 4986ff87d8..7581eeedd4 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -173,42 +173,14 @@ function condition_gibbs(model::DynamicPPL.Model, values...) return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) end -""" - make_conditional_model(model, varinfo, varinfos) - -Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. - -# Examples -```julia-repl -julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); - -julia> # A separate varinfo for each variable in `model`. - varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); - -julia> # The varinfo we want to NOT condition on. - target_varinfo = first(varinfos); - -julia> # Results in a model with only `m` conditioned. - conditioned_model = make_conditional(model, target_varinfo, varinfos); - -julia> result = conditioned_model(); - -julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` -true - -julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` -true -``` -""" function make_conditional( - model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos + model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) - # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return condition_gibbs(model, filter(Base.Fix1(!==, target_varinfo), varinfos)...) -end -# Assumes the ones given are the ones to condition on. -function make_conditional(model::DynamicPPL.Model, varinfos) - return condition_gibbs(model, varinfos...) + not_target_variables = filter( + x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) + ) + vi_filtered = subset(varinfo, not_target_variables) + return condition_gibbs(model, vi_filtered) end # HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` @@ -219,13 +191,15 @@ wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) """ gibbs_state(model, sampler, state, varinfo) -Return an updated state, taking into account the variables sampled by other Gibbs components. +Return an updated state for a component sampler. + +This takes into account changes caused by other Gibbs components. # Arguments - `model`: model targeted by the Gibbs sampler. - `sampler`: the sampler for this Gibbs component. - `state`: the state of `sampler` computed in the previous iteration. -- `varinfo`: the variables, including the ones sampled by other Gibbs components. +- `varinfo`: the current values of the variables relevant for this sampler. """ gibbs_state(model, sampler, state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo function gibbs_state(model, sampler, state::PGState, varinfo::AbstractVarInfo) @@ -237,12 +211,13 @@ function gibbs_state( model::Model, spl::Sampler{<:Hamiltonian}, state::HMCState, varinfo::AbstractVarInfo ) # Update hamiltonian - θ_old = varinfo[spl] - hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_old)) + θ_new = varinfo[:] + hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_new)) + # Update the parameter values in `state.z`. # TODO: Avoid mutation - resize!(state.z.θ, length(θ_old)) - state.z.θ .= θ_old + resize!(state.z.θ, length(θ_new)) + state.z.θ .= θ_new z = state.z return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) @@ -348,55 +323,41 @@ function DynamicPPL.initialstep( end # Create the varinfos for each sampler. - varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) + local_varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) initial_params_all = if initial_params === nothing fill(nothing, length(varnames)) else # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], varinfos) + map(vi -> vi[:], local_varinfos) end # 2. Construct a varinfo for every vn + sampler combo. - states_and_varinfos = map( - samplers, varinfos, initial_params_all - ) do sampler_local, varinfo_local, initial_params_local + states = [] + for (varnames_local, sampler_local, initial_params_local) in + zip(varnames, samplers, initial_params_all) # Construct the conditional model. - model_local = make_conditional(model, varinfo_local, varinfos) + model_local = make_conditional(model, _maybevec(varnames_local), vi_base) # Take initial step. - new_state_local = last( - AbstractMCMC.step( - rng, - model_local, - sampler_local; - # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. - # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. - initial_params=initial_params_local, - kwargs..., - ), + _, new_state_local = AbstractMCMC.step( + rng, + model_local, + sampler_local; + # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. + # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. + initial_params=initial_params_local, + kwargs..., ) - - # Return the new state and the invlinked `varinfo`. - vi_local_state = varinfo(new_state_local) - vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) - DynamicPPL.invlink(vi_local_state, sampler_local, model_local) + vi_local = varinfo(new_state_local) + vi_local = if DynamicPPL.istrans(vi_local) + DynamicPPL.invlink(vi_local, sampler_local, model_local) else - vi_local_state + vi_local end - return (new_state_local, vi_local_state_linked) + vi_base = merge(vi_base, vi_local) + push!(states, new_state_local) end - - states = map(first, states_and_varinfos) - varinfos = map(last, states_and_varinfos) - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) - ) - - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi_base), GibbsState(vi_base, states) end function AbstractMCMC.step( @@ -406,37 +367,23 @@ function AbstractMCMC.step( state::GibbsState; kwargs..., ) + vi = varinfo(state) alg = spl.alg + varnames = alg.varnames samplers = alg.samplers states = state.states - varinfos = map(varinfo, state.states) @assert length(samplers) == length(state.states) # TODO: move this into a recursive function so we can unroll when reasonable? for index in 1:length(samplers) # Take the inner step. - new_state_local, new_varinfo_local = gibbs_step_inner( - rng, model, samplers, states, varinfos, index; kwargs... + vi, new_state_local = gibbs_step_inner( + rng, model, varnames, samplers, states, vi, index; kwargs... ) - # Update the `states` and `varinfos`. + # Update the `states` states = Accessors.setindex(states, new_state_local, index) - varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) end - - # Combine the resulting varinfo objects. - # The last varinfo holds the correctly computed logp. - vi_base = state.vi - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!( - varinfos, merge(vi_base, first(varinfos)), firstindex(varinfos) - ) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) - ) - return Transition(model, vi), GibbsState(vi, states) end @@ -486,40 +433,50 @@ function recompute_logprob!!( return gibbs_state(model, sampler, state, vi_new) end +AbstractMCMC.setparams!!(::VarInfo, vi::VarInfo) = vi +function AbstractMCMC.setparams!!(state, vi::VarInfo) + # In the fallback implementation we guess that `state` has a field called `vi` we can + # set. Fingers crossed! + try + return Accessors.set(state, Accessors.PropertyLens{:vi}(), vi) + catch + error( + "Unable to set `state.vi` for a $(typeof(state)). " * + "Consider writing a method for setparams!! for this type.", + ) + end +end + function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, + varnames, samplers, states, - varinfos, + vi, index; kwargs..., ) # Needs to do a a few things. sampler_local = samplers[index] state_local = states[index] - varinfo_local = varinfos[index] - - # Make sure that all `varinfos` are linked. - varinfos_invlinked = map(varinfos) do vi - # NOTE: This is immutable linking! - # TODO: Do we need the `istrans` check here or should we just always use `invlink`? - # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 - DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - end - varinfo_local_invlinked = varinfos_invlinked[index] + varnames_local = _maybevec(varnames[index]) + + vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi # 1. Create conditional model. # Construct the conditional model. # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, # otherwise we're conditioning on values which are not in the support of the # distributions. - model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) + model_local = make_conditional(model, varnames_local, vi) + varinfo_local = subset(vi, varnames_local) # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] state_previous = states[index == 1 ? length(states) : index - 1] + state_local = AbstractMCMC.setparams!!(state_local, varinfo_local) # 1. Re-run the sampler if needed. if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous @@ -532,15 +489,6 @@ function gibbs_step_inner( AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) ) - # 3. Extract the new varinfo. - # Return the resulting state and invlinked `varinfo`. - varinfo_local_state = varinfo(new_state_local) - varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) - DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) - else - varinfo_local_state - end - - # TODO: alternatively, we can return `states_new, varinfos_new, index_new` - return (new_state_local, varinfo_local_state_invlinked) + new_vi = merge(vi, varinfo(new_state_local)) + return new_vi, new_state_local end From 19598c4bd8d7f15db86fa1035228308307443e7e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 17 Oct 2024 17:17:19 +0100 Subject: [PATCH 12/70] Fix the Gibbs sampler more --- src/mcmc/abstractmcmc.jl | 2 +- src/mcmc/gibbs.jl | 262 ++++++++++++++++++++++++++------------- 2 files changed, 180 insertions(+), 84 deletions(-) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index 7e6c64d110..c6ca61a9c4 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -65,7 +65,7 @@ function recompute_logprob!!( rng::Random.AbstractRNG, # TODO: Do we need the `rng` here? model::DynamicPPL.Model, sampler::DynamicPPL.Sampler{<:ExternalSampler}, - state, + state, # TODO(mhauru) Could we type constrain this to TuringState? ) # Re-using the log-density function from the `state` and updating only the `model` field, # since the `model` might now contain different conditioning values. diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 7581eeedd4..aa01a3b1fa 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -127,52 +127,33 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) end """ - condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) + condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) -Return a `GibbsContext` with the given values treated as conditioned. - -# Arguments -- `context::DynamicPPL.AbstractContext`: The context to condition. -- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. - If multiple values are provided, we recursively condition on each of them. +Return a `GibbsContext` with the values extracted from the given `varinfo` treated as +conditioned. """ -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. function condition_gibbs( - context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict} + context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo ) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs(condition_gibbs(context, value), values...) + # TODO(mhauru) Maybe use preferred_value_type to return NamedTuples in some cases. + # If not, then remove preferred_value_type. + vals = DynamicPPL.OrderedDict(k => varinfo[k] for k in keys(varinfo)) + return GibbsContext(vals, context) end -# For `DynamicPPL.AbstractVarInfo` we just extract the values. """ - condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) + make_conditional(model, target_variables, varinfo) -Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. -""" -function condition_gibbs( - context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo -) - return condition_gibbs( - context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)) - ) -end -function condition_gibbs( - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo, - varinfos::DynamicPPL.AbstractVarInfo..., -) - return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) -end -# Allow calling this on a `DynamicPPL.Model` directly. -function condition_gibbs(model::DynamicPPL.Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end +Return a new, conditioned model for a component of a Gibbs sampler. +# Arguments +- `model::DynamicPPL.Model`: The model to condition. +- `target_variables::AbstractVector{<:VarName}`: The target variables of the component +sampler. These will _not_ conditioned. +- `varinfo::DynamicPPL.AbstractVarInfo`: Values for all variables in the model. All the +values in `varinfo` but not in `target_variables` will be conditioned to the values they +have in `varinfo`. +""" function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) @@ -180,7 +161,8 @@ function make_conditional( x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) ) vi_filtered = subset(varinfo, not_target_variables) - return condition_gibbs(model, vi_filtered) + gibbs_context = condition_gibbs(model.context, vi_filtered) + return DynamicPPL.contextualize(model, gibbs_context) end # HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` @@ -188,41 +170,6 @@ end wrap_algorithm_maybe(x) = x wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) -""" - gibbs_state(model, sampler, state, varinfo) - -Return an updated state for a component sampler. - -This takes into account changes caused by other Gibbs components. - -# Arguments -- `model`: model targeted by the Gibbs sampler. -- `sampler`: the sampler for this Gibbs component. -- `state`: the state of `sampler` computed in the previous iteration. -- `varinfo`: the current values of the variables relevant for this sampler. -""" -gibbs_state(model, sampler, state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo -function gibbs_state(model, sampler, state::PGState, varinfo::AbstractVarInfo) - return PGState(varinfo, state.rng) -end - -# Update state in Gibbs sampling -function gibbs_state( - model::Model, spl::Sampler{<:Hamiltonian}, state::HMCState, varinfo::AbstractVarInfo -) - # Update hamiltonian - θ_new = varinfo[:] - hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_new)) - - # Update the parameter values in `state.z`. - # TODO: Avoid mutation - resize!(state.z.θ, length(θ_new)) - state.z.θ .= θ_new - z = state.z - - return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) -end - """ Gibbs @@ -349,6 +296,7 @@ function DynamicPPL.initialstep( kwargs..., ) vi_local = varinfo(new_state_local) + # TODO(mhauru) Can we remove the invlinking? vi_local = if DynamicPPL.istrans(vi_local) DynamicPPL.invlink(vi_local, sampler_local, model_local) else @@ -428,25 +376,159 @@ function recompute_logprob!!( DynamicPPL.SamplingContext(rng, sampler_rerun), ) ) - # Update the state we're about to use if need be. - # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - return gibbs_state(model, sampler, state, vi_new) + return setlogp!!(state, vi_new.logp[]) end -AbstractMCMC.setparams!!(::VarInfo, vi::VarInfo) = vi -function AbstractMCMC.setparams!!(state, vi::VarInfo) +# TODO(mhauru) Would really like to type constraint this to something like AbstractMCMCState +# if such a thing existed. +function DynamicPPL.setlogp!!(state, logp) + try + new_vi = setlogp!!(state.vi, logp) + if new_vi !== state.vi + return Accessors.set(state, Accessors.PropertyLens{:vi}(), new_vi) + else + return state + end + catch + error( + "Unable to set `state.vi` for a $(typeof(state)). " * + "Consider writing a method for `setlogp!!` for this type.", + ) + end +end + +function DynamicPPL.setlogp!!(state::TuringState, logp) + return TuringState(setlogp!!(state.state, logp), logp) +end + +# TODO(mhauru) In the general case, which arguments are really needed for reset_state!!? +# The current list is a guess, but I think some might be unnecessary. +""" + reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous) + +Return an updated state for a component sampler. + +This takes into account changes caused by other Gibbs components. The default implementation +is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to +do, a method should be implemented for the specific type of `state`. + +# Arguments +- `model::DynamicPPL.Model`: The model as seen by this component sampler. Variables not +sampled by this component sampler have been conditioned with a `GibbsContext`. +- `sampler::DynamicPPL.Sampler`: The current component sampler. +- `state`: The state of this component sampler from its previous iteration. +- `varinfo::DynamicPPL.AbstractVarInfo`: The current `VarInfo`, subsetted to the variables +sampled by this component sampler. +- `sampler_previous::DynamicPPL.Sampler`: The previous sampler in the Gibbs chain. +- `state_previous`: The state returned by the previous sampler. + +# Returns +An updated state of the same type as `state`. It should have variables set to the values in +`varinfo`, and any other relevant updates done. +""" +function reset_state!!( + model, sampler, state, varinfo::AbstractVarInfo, sampler_previous, state_previous +) # In the fallback implementation we guess that `state` has a field called `vi` we can # set. Fingers crossed! try - return Accessors.set(state, Accessors.PropertyLens{:vi}(), vi) + return Accessors.set(state, Accessors.PropertyLens{:vi}(), varinfo) catch error( "Unable to set `state.vi` for a $(typeof(state)). " * - "Consider writing a method for setparams!! for this type.", + "Consider writing a method for reset_state!! for this type.", ) end end +function reset_state!!( + model, + sampler, + state::AbstractVarInfo, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + return varinfo +end + +function reset_state!!( + model, + sampler, + state::TuringState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + new_inner_state = reset_state!!( + model, sampler, state.state, varinfo, sampler_previous, state_previous + ) + return TuringState(new_inner_state, state.logdensity) +end + +function reset_state!!( + model, + sampler, + state::HMCState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + θ_new = varinfo[:] + hamiltonian = get_hamiltonian(model, sampler, varinfo, state, length(θ_new)) + + # Update the parameter values in `state.z`. + # TODO: Avoid mutation + z = state.z + resize!(z.θ, length(θ_new)) + z.θ .= θ_new + return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) +end + +function reset_state!!( + model, + sampler, + state::AdvancedHMC.HMCState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + hamiltonian = AdvancedHMC.Hamiltonian( + state.metric, DynamicPPL.LogDensityFunction(model) + ) + θ_new = varinfo[:] + # Set the momentum to zero, since we have no idea what it should be at the new parameter + # values. + return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, θ_new, zero(θ_new) + ) +end + +function reset_state!!( + model, + sampler, + state::AdvancedMH.Transition, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + # TODO(mhauru) Setting the last argument like this seems a bit suspect, since the + # current values for the parameters might not have come from this sampler at all. + # I don't see a better way though. + return AdvancedMH.Transition(varinfo[:], varinfo.logp[], state.accepted) +end + +function reset_state!!( + model, + sampler, + state::PGState, + varinfo::AbstractVarInfo, + sampler_previous, + state_previous, +) + return PGState(varinfo, state.rng) +end + function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -462,6 +544,7 @@ function gibbs_step_inner( state_local = states[index] varnames_local = _maybevec(varnames[index]) + # TODO(mhauru) Can we remove the invlinking? vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi # 1. Create conditional model. @@ -471,13 +554,24 @@ function gibbs_step_inner( # distributions. model_local = make_conditional(model, varnames_local, vi) varinfo_local = subset(vi, varnames_local) + # If the varinfo of the previous state from this sampler is linked, we should link the + # new varinfo too. + if DynamicPPL.istrans(varinfo(state_local)) + varinfo_local = DynamicPPL.link(varinfo_local, sampler_local, model_local) + end # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] state_previous = states[index == 1 ? length(states) : index - 1] - state_local = AbstractMCMC.setparams!!(state_local, varinfo_local) - # 1. Re-run the sampler if needed. + state_local = reset_state!!( + model_local, + sampler_local, + state_local, + varinfo_local, + sampler_previous, + state_previous, + ) if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous ) @@ -489,6 +583,8 @@ function gibbs_step_inner( AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) ) - new_vi = merge(vi, varinfo(new_state_local)) + new_vi_local = varinfo(new_state_local) + new_vi = merge(vi, new_vi_local) + new_vi = setlogp!!(new_vi, new_vi_local.logp[]) return new_vi, new_state_local end From 71f26ca09dbd7128b0a9e75bea66b734f440fae3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 18 Oct 2024 11:09:01 +0100 Subject: [PATCH 13/70] Remove mentions of old Gibbs sampler from MH docs Co-authored-by: Penelope Yong --- src/mcmc/mh.jl | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index ffc064eb12..edec843654 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -20,7 +20,6 @@ Construct a Metropolis-Hastings algorithm. The arguments `space` can be - Blank (i.e. `MH()`), in which case `MH` defaults to using the prior for each parameter as the proposal distribution. -- A set of one or more symbols to sample with `MH` in conjunction with `Gibbs`, i.e. `Gibbs(MH(:m), PG(10, :s))` - An iterable of pairs or tuples mapping a `Symbol` to a `AdvancedMH.Proposal`, `Distribution`, or `Function` that generates returns a conditional proposal distribution. - A covariance matrix to use as for mean-zero multivariate normal proposals. @@ -41,22 +40,6 @@ chain = sample(gdemo(1.5, 2.0), MH(), 1_000) mean(chain) ``` -Alternatively, you can specify particular parameters to sample if you want to combine sampling -from multiple samplers: - -```julia -@model function gdemo(x, y) - s² ~ InverseGamma(2,3) - m ~ Normal(0, sqrt(s²)) - x ~ Normal(m, sqrt(s²)) - y ~ Normal(m, sqrt(s²)) -end - -# Samples s with MH and m with PG -chain = sample(gdemo(1.5, 2.0), Gibbs(MH(:s), PG(10, :m)), 1_000) -mean(chain) -``` - Using custom distributions defaults to using static MH: ```julia From d0f57acf249abc2c4463f3fe93fc8b07d3738472 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 21 Oct 2024 14:29:42 +0100 Subject: [PATCH 14/70] Bump DPPL to 0.28.6 --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index e23cb5f0be..1b0003ab2b 100644 --- a/Project.toml +++ b/Project.toml @@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.28.2" +DynamicPPL = "0.28.6" Compat = "4.15.0" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" diff --git a/test/Project.toml b/test/Project.toml index 67292d2af5..7d463eb8a6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -45,7 +45,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.28" +DynamicPPL = "0.28.6" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" From 74b57e716915d5d01b51c3588eda79583d1b7ace Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 11:32:13 +0000 Subject: [PATCH 15/70] Redesign GibbsContext, work in progress --- Project.toml | 2 +- src/mcmc/gibbs.jl | 173 ++++++++++++++++++++++++++++------------------ 2 files changed, 107 insertions(+), 68 deletions(-) diff --git a/Project.toml b/Project.toml index c9b9c71dec..ab43ccbc39 100644 --- a/Project.toml +++ b/Project.toml @@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.30" +DynamicPPL = "0.30.1" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.8.8" diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index aa01a3b1fa..21e062fd15 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -10,44 +10,80 @@ # rather than only for the "true" observations. # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values +struct GibbsContext{ + VNs,Values,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext +} <: DynamicPPL.AbstractContext + target_varnames::VNs + conditioned_values::Values + global_varinfo::GVI context::Ctx end -GibbsContext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) +function GibbsContext(target_varnames, conditioned_values, global_varinfo) + return GibbsContext( + target_varnames, conditioned_values, global_varinfo, DynamicPPL.DefaultContext() + ) +end DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) - return GibbsContext(context.values, childcontext) + return GibbsContext( + context.target_varnames, + context.conditioned_values, + Ref(context.global_varinfo[]), + childcontext, + ) end # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.hasvalue(context.values, vn) + return DynamicPPL.hasvalue(context.conditioned_values, vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(has_conditioned_gibbs, context), vns) end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.getvalue(context.values, vn) + return DynamicPPL.getvalue(context.conditioned_values, vn) end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) end +function is_target_varname(context::GibbsContext, vn::VarName) + return Iterators.any( + Iterators.map(target -> subsumes(target, vn), context.target_varnames) + ) +end + # Tilde pipeline function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Short-circuits the tilde assume if `vn` is present in `context`. if has_conditioned_gibbs(context, vn) value = get_conditioned_gibbs(context, vn) + # TODO(mhauru) Is the call to logpdf correct if context.context is not + # DefaultContext? return value, logpdf(right, value), vi + elseif is_target_varname(context, vn) + # Fall back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + value, lp, new_global_vi = DynamicPPL.tilde_assume( + DynamicPPL.SamplingContext( + DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) + ), + right, + vn, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) end function DynamicPPL.tilde_assume( @@ -56,13 +92,30 @@ function DynamicPPL.tilde_assume( # Short-circuits the tilde assume if `vn` is present in `context`. if has_conditioned_gibbs(context, vn) value = get_conditioned_gibbs(context, vn) + # TODO(mhauru) Is the call to logpdf correct if context.context is not + # DefaultContext? return value, logpdf(right, value), vi + elseif is_target_varname(context, vn) + # Fall back to the default behavior. + return DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + value, lp, new_global_vi = DynamicPPL.tilde_assume( + DynamicPPL.SamplingContext( + rng, DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) + ), + right, + vn, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, vn, vi - ) end # Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. @@ -126,21 +179,6 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict end -""" - condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) - -Return a `GibbsContext` with the values extracted from the given `varinfo` treated as -conditioned. -""" -function condition_gibbs( - context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo -) - # TODO(mhauru) Maybe use preferred_value_type to return NamedTuples in some cases. - # If not, then remove preferred_value_type. - vals = DynamicPPL.OrderedDict(k => varinfo[k] for k in keys(varinfo)) - return GibbsContext(vals, context) -end - """ make_conditional(model, target_variables, varinfo) @@ -157,12 +195,15 @@ have in `varinfo`. function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) + # We want to condition all the variables in keys(varinfo) that are not subsumed by any + # of the target variables. not_target_variables = filter( x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) ) vi_filtered = subset(varinfo, not_target_variables) - gibbs_context = condition_gibbs(model.context, vi_filtered) - return DynamicPPL.contextualize(model, gibbs_context) + vals = DynamicPPL.OrderedDict(k => vi_filtered[k] for k in keys(vi_filtered)) + gibbs_context = GibbsContext(target_variables, vals, Ref(varinfo), model.context) + return DynamicPPL.contextualize(model, gibbs_context), gibbs_context end # HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` @@ -261,29 +302,21 @@ function DynamicPPL.initialstep( samplers = alg.samplers # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(rng, model) - - # Simple way of setting the initial parameters: set them in the `vi_base` - # if they are given so they propagate to the subset varinfos used by each sampler. + vi = DynamicPPL.VarInfo(rng, model) if initial_params !== nothing - vi_base = DynamicPPL.unflatten(vi_base, initial_params) - end - - # Create the varinfos for each sampler. - local_varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) - initial_params_all = if initial_params === nothing - fill(nothing, length(varnames)) - else - # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], local_varinfos) + vi = DynamicPPL.unflatten(vi, initial_params) end - # 2. Construct a varinfo for every vn + sampler combo. + # Initialise each component sampler in turn, collect all their states. states = [] - for (varnames_local, sampler_local, initial_params_local) in - zip(varnames, samplers, initial_params_all) + for (varnames_local, sampler_local) in zip(varnames, samplers) + varnames_local = _maybevec(varnames_local) + # Get the initial values for this component sampler. + vi_local = DynamicPPL.subset(vi, varnames_local) + initial_params_local = initial_params === nothing ? nothing : vi_local[:] + # Construct the conditional model. - model_local = make_conditional(model, _maybevec(varnames_local), vi_base) + model_local, context_local = make_conditional(model, varnames_local, vi) # Take initial step. _, new_state_local = AbstractMCMC.step( @@ -295,17 +328,21 @@ function DynamicPPL.initialstep( initial_params=initial_params_local, kwargs..., ) - vi_local = varinfo(new_state_local) + new_vi_local = varinfo(new_state_local) # TODO(mhauru) Can we remove the invlinking? - vi_local = if DynamicPPL.istrans(vi_local) - DynamicPPL.invlink(vi_local, sampler_local, model_local) + new_vi_local = if DynamicPPL.istrans(new_vi_local) + DynamicPPL.invlink(new_vi_local, sampler_local, model_local) else - vi_local + new_vi_local end - vi_base = merge(vi_base, vi_local) + # This merges in any new variables that were introduced during the step, but that + # were not in the domain of the current sampler. + vi = merge(vi, context_local.global_varinfo[]) + # This merges the latest values for all the variables in the current sampler. + vi = merge(vi, new_vi_local) push!(states, new_state_local) end - return Transition(model, vi_base), GibbsState(vi_base, states) + return Transition(model, vi), GibbsState(vi, states) end function AbstractMCMC.step( @@ -328,8 +365,6 @@ function AbstractMCMC.step( vi, new_state_local = gibbs_step_inner( rng, model, varnames, samplers, states, vi, index; kwargs... ) - - # Update the `states` states = Accessors.setindex(states, new_state_local, index) end return Transition(model, vi), GibbsState(vi, states) @@ -379,7 +414,7 @@ function recompute_logprob!!( return setlogp!!(state, vi_new.logp[]) end -# TODO(mhauru) Would really like to type constraint this to something like AbstractMCMCState +# TODO(mhauru) Would really like to type constrain this to something like AbstractMCMCState # if such a thing existed. function DynamicPPL.setlogp!!(state, logp) try @@ -441,6 +476,8 @@ function reset_state!!( end end +# Some samplers use a VarInfo directly as the state. In that case, there's little to do in +# `reset_state!!`. function reset_state!!( model, sampler, @@ -539,7 +576,6 @@ function gibbs_step_inner( index; kwargs..., ) - # Needs to do a a few things. sampler_local = samplers[index] state_local = states[index] varnames_local = _maybevec(varnames[index]) @@ -547,13 +583,10 @@ function gibbs_step_inner( # TODO(mhauru) Can we remove the invlinking? vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - # 1. Create conditional model. - # Construct the conditional model. - # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, - # otherwise we're conditioning on values which are not in the support of the - # distributions. - model_local = make_conditional(model, varnames_local, vi) + # Construct the conditional model and the varinfo that this sampler should use. + model_local, context_local = make_conditional(model, varnames_local, vi) varinfo_local = subset(vi, varnames_local) + # TODO(mhauru) Can we remove the below, if get rid of all the invlinking? # If the varinfo of the previous state from this sampler is linked, we should link the # new varinfo too. if DynamicPPL.istrans(varinfo(state_local)) @@ -564,6 +597,8 @@ function gibbs_step_inner( sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] state_previous = states[index == 1 ? length(states) : index - 1] + # Set the state of the current sampler, accounting for any changes made by other + # samplers. state_local = reset_state!!( model_local, sampler_local, @@ -578,13 +613,17 @@ function gibbs_step_inner( state_local = recompute_logprob!!(rng, model_local, sampler_local, state_local) end - # 2. Take step with local sampler. + # Take a step with the local sampler. new_state_local = last( AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) ) new_vi_local = varinfo(new_state_local) - new_vi = merge(vi, new_vi_local) + # This merges in any new variables that were introduced during the step, but that + # were not in the domain of the current sampler. + new_vi = merge(vi, context_local.global_varinfo[]) + # This merges the latest values for all the variables in the current sampler. + new_vi = merge(new_vi, new_vi_local) new_vi = setlogp!!(new_vi, new_vi_local.logp[]) return new_vi, new_state_local end From b16daf5eff519f08e40b712a0aa1cfc3a95146f0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 15:34:33 +0000 Subject: [PATCH 16/70] Fixing new Gibbs, adding a broken test --- src/mcmc/gibbs.jl | 132 ++++++++++++++++++++++++++------------------- test/mcmc/gibbs.jl | 35 ++++++++++-- 2 files changed, 110 insertions(+), 57 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 21e062fd15..21d71d9b28 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -10,42 +10,35 @@ # rather than only for the "true" observations. # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{ - VNs,Values,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext -} <: DynamicPPL.AbstractContext +struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: + DynamicPPL.AbstractContext target_varnames::VNs - conditioned_values::Values global_varinfo::GVI context::Ctx end -function GibbsContext(target_varnames, conditioned_values, global_varinfo) - return GibbsContext( - target_varnames, conditioned_values, global_varinfo, DynamicPPL.DefaultContext() - ) +function GibbsContext(target_varnames, global_varinfo) + return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext()) end DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) return GibbsContext( - context.target_varnames, - context.conditioned_values, - Ref(context.global_varinfo[]), - childcontext, + context.target_varnames, Ref(context.global_varinfo[]), childcontext ) end # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.hasvalue(context.conditioned_values, vn) + return DynamicPPL.haskey(context.global_varinfo[], vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(has_conditioned_gibbs, context), vns) end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.getvalue(context.conditioned_values, vn) + return context.global_varinfo[][vn] end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) @@ -57,26 +50,30 @@ function is_target_varname(context::GibbsContext, vn::VarName) ) end +function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(is_target_varname, context), vns) +end + # Tilde pipeline function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) + if is_target_varname(context, vn) + # Fall back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) + elseif has_conditioned_gibbs(context, vn) + # Short-circuits the tilde assume if `vn` is present in `context`. value = get_conditioned_gibbs(context, vn) # TODO(mhauru) Is the call to logpdf correct if context.context is not # DefaultContext? return value, logpdf(right, value), vi - elseif is_target_varname(context, vn) - # Fall back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( - DynamicPPL.SamplingContext( - DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) - ), + DynamicPPL.childcontext(context), + prior_sampler, right, vn, context.global_varinfo[], @@ -89,26 +86,27 @@ end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi ) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - # TODO(mhauru) Is the call to logpdf correct if context.context is not - # DefaultContext? - return value, logpdf(right, value), vi - elseif is_target_varname(context, vn) + if is_target_varname(context, vn) # Fall back to the default behavior. return DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) + elseif has_conditioned_gibbs(context, vn) + # Short-circuits the tilde assume if `vn` is present in `context`. + value = get_conditioned_gibbs(context, vn) + # TODO(mhauru) Is the call to logpdf correct if context.context is not + # DefaultContext? + return value, logpdf(right, value), vi else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( - DynamicPPL.SamplingContext( - rng, DynamicPPL.SampleFromPrior(), DynamicPPL.childcontext(context) - ), + rng, + DynamicPPL.childcontext(context), + prior_sampler, right, vn, context.global_varinfo[], @@ -137,31 +135,64 @@ function reconstruct_getvalue( end function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) + if is_target_varname(context, vns) + # Fall back to the default behavior. + return DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), right, left, vns, vi + ) + elseif has_conditioned_gibbs(context, vns) + # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) return value, broadcast_logpdf(right, value), vi + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() + value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), + prior_sampler, + right, + left, + vns, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume( - DynamicPPL.childcontext(context), right, left, vns, vi - ) end function DynamicPPL.dot_tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi ) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) + if is_target_varname(context, vns) + # Fall back to the default behavior. + return DynamicPPL.dot_tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi + ) + elseif has_conditioned_gibbs(context, vns) + # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) return value, broadcast_logpdf(right, value), vi + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + prior_sampler = DynamicPPL.SampleFromPrior() + value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( + rng, + DynamicPPL.childcontext(context), + prior_sampler, + right, + left, + vns, + context.global_varinfo[], + ) + context.global_varinfo[] = new_global_vi + return value, lp, vi end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi - ) end """ @@ -195,14 +226,7 @@ have in `varinfo`. function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) - # We want to condition all the variables in keys(varinfo) that are not subsumed by any - # of the target variables. - not_target_variables = filter( - x -> !(any(Iterators.map(vn -> subsumes(vn, x), target_variables))), keys(varinfo) - ) - vi_filtered = subset(varinfo, not_target_variables) - vals = DynamicPPL.OrderedDict(k => vi_filtered[k] for k in keys(vi_filtered)) - gibbs_context = GibbsContext(target_variables, vals, Ref(varinfo), model.context) + gibbs_context = GibbsContext(target_variables, Ref(varinfo), model.context) return DynamicPPL.contextualize(model, gibbs_context), gibbs_context end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index fc3ea6352d..c11f291628 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -15,7 +15,7 @@ using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff import Mooncake -using Test: @test, @test_deprecated, @testset +using Test: @test, @test_broken, @test_deprecated, @testset using Turing using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH @@ -135,8 +135,16 @@ has_dot_assume(::DynamicPPL.Model) = true Random.seed!(200) for alg in [ # The new syntax for specifying a sampler to run twice for one variable. - Gibbs(s => MH(), s => MH(), m => HMC(0.2, 4; adtype=adbackend)), - Gibbs(s => MH(), m => HMC(0.2, 4), m => HMC(0.2, 4); adtype=adbackend), + Gibbs( + @varname(s) => MH(), + @varname(s) => MH(), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + ), + Gibbs( + @varname(s) => MH(), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + ), ] chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain; atol=0.15) @@ -200,6 +208,27 @@ has_dot_assume(::DynamicPPL.Model) = true sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) end + @testset "dynamic model with dot tilde" begin + @model function dynamic_model_with_dot_tilde(num_zs=10) + z = Vector(undef, num_zs) + z .~ Exponential(1.0) + num_ms = Int(round(sum(z))) + m = Vector(undef, num_ms) + return m .~ Normal(1.0, 1.0) + end + model = dynamic_model_with_dot_tilde() + # TODO(mhauru) This is broken because of + # https://github.com/TuringLang/DynamicPPL.jl/issues/700. + @test_broken ( + sample( + model, + Gibbs(; z=NUTS(; adtype=adbackend), m=HMC(0.01, 4; adtype=adbackend)), + 100, + ); + true + ) + end + @testset "Demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) From af802dc9e6715d7dc00b151db8b6721a5eaad3f6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 16:43:15 +0000 Subject: [PATCH 17/70] Document and clean up GibbsContext --- src/mcmc/gibbs.jl | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 21d71d9b28..32ad22db3f 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -10,10 +10,31 @@ # rather than only for the "true" observations. # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. +""" + GibbsContext(target_varnames, global_varinfo, context) + +A context used in the implementation of the Turing.jl Gibbs sampler. + +There will be one `GibbsContext` for each iteration of a component sampler. +""" struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + """ + a collection of `VarName`s that are the ones the current component sampler is sampling. + For them, `GibbsContext` will just pass tilde_assume calls to its child context. + For other variables, their values will be fixed to the values they have in + `global_varinfo`. + """ target_varnames::VNs + """ + a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both + those fixed and those being sampled. We use a `Ref` because this field may need to be + updated if new variables are introduced. + """ global_varinfo::GVI + """ + the child context that tilde calls will eventually be passed onto. + """ context::Ctx end @@ -34,7 +55,14 @@ function has_conditioned_gibbs(context::GibbsContext, vn::VarName) return DynamicPPL.haskey(context.global_varinfo[], vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) + num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns)) + if (num_conditioned != 0) && (num_conditioned != length(vns)) + error( + "Some but not all of the variables in `vns` have been conditioned on. " * + "Having mixed conditioning like this is not supported in GibbsContext.", + ) + end + return num_conditioned > 0 end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) @@ -51,7 +79,14 @@ function is_target_varname(context::GibbsContext, vn::VarName) end function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(is_target_varname, context), vns) + num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns)) + if (num_target != 0) && (num_target != length(vns)) + error( + "Some but not all of the variables in `vns` are target variables. " * + "Having mixed targeting like this is not supported in GibbsContext.", + ) + end + return num_target > 0 end # Tilde pipeline From e58d93560794ec77cbac639a936542449ffe2a1a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Oct 2024 17:03:57 +0000 Subject: [PATCH 18/70] Code style and docs improvements to Gibbs --- src/mcmc/gibbs.jl | 79 +++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 50 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 32ad22db3f..9034e67080 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -95,7 +95,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # Short-circuits the tilde assume if `vn` is present in `context`. + # Short-circuit the tilde assume if `vn` is present in `context`. value = get_conditioned_gibbs(context, vn) # TODO(mhauru) Is the call to logpdf correct if context.context is not # DefaultContext? @@ -105,10 +105,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. - prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( DynamicPPL.childcontext(context), - prior_sampler, + DynamicPPL.SampleFromPrior(), right, vn, context.global_varinfo[], @@ -118,30 +117,24 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) end end +# As above but with an RNG. function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi ) + # See comment in the above, rng-less version of this method for an explanation. if is_target_varname(context, vn) - # Fall back to the default behavior. return DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) elseif has_conditioned_gibbs(context, vn) - # Short-circuits the tilde assume if `vn` is present in `context`. value = get_conditioned_gibbs(context, vn) - # TODO(mhauru) Is the call to logpdf correct if context.context is not - # DefaultContext? + # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? return value, logpdf(right, value), vi else - # If the varname has not been conditioned on, nor is it a target variable, its - # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one - # being used by the current sampler. - prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), - prior_sampler, + DynamicPPL.SampleFromPrior(), right, vn, context.global_varinfo[], @@ -169,21 +162,18 @@ function reconstruct_getvalue( return reduce(hcat, x[2:end]; init=x[1]) end +# Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf. +# See comments there for more details. function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) if is_target_varname(context, vns) - # Fall back to the default behavior. return DynamicPPL.dot_tilde_assume( DynamicPPL.childcontext(context), right, left, vns, vi ) elseif has_conditioned_gibbs(context, vns) - # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? return value, broadcast_logpdf(right, value), vi else - # If the varname has not been conditioned on, nor is it a target variable, its - # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one - # being used by the current sampler. prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( DynamicPPL.childcontext(context), @@ -198,23 +188,19 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi end end +# As above but with an RNG. function DynamicPPL.dot_tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi ) if is_target_varname(context, vns) - # Fall back to the default behavior. return DynamicPPL.dot_tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi ) elseif has_conditioned_gibbs(context, vns) - # Short-circuit the tilde assume if `vn` is present in `context`. value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? return value, broadcast_logpdf(right, value), vi else - # If the varname has not been conditioned on, nor is it a target variable, its - # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one - # being used by the current sampler. prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( rng, @@ -230,21 +216,6 @@ function DynamicPPL.dot_tilde_assume( end end -""" - preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) - -Returns the preferred value type for a variable with the given `varinfo`. -""" -preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict -preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple -function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) - # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. - namedtuple_compatible = all(varinfo.metadata) do md - eltype(md.vns) <: VarName{<:Any,typeof(identity)} - end - return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict -end - """ make_conditional(model, target_variables, varinfo) @@ -253,10 +224,15 @@ Return a new, conditioned model for a component of a Gibbs sampler. # Arguments - `model::DynamicPPL.Model`: The model to condition. - `target_variables::AbstractVector{<:VarName}`: The target variables of the component -sampler. These will _not_ conditioned. +sampler. These will _not_ be conditioned. - `varinfo::DynamicPPL.AbstractVarInfo`: Values for all variables in the model. All the values in `varinfo` but not in `target_variables` will be conditioned to the values they have in `varinfo`. + +# Returns +- A new model with the variables _not_ in `target_variables` conditioned. +- The `GibbsContext` object that will be used to condition the variables. This is necessary +because evaluation can mutate its `global_varinfo` field, which we need to access later. """ function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo @@ -360,7 +336,7 @@ function DynamicPPL.initialstep( varnames = alg.varnames samplers = alg.samplers - # 1. Run the model once to get the varnames present + initial values to condition on. + # Run the model once to get the varnames present + initial values to condition on. vi = DynamicPPL.VarInfo(rng, model) if initial_params !== nothing vi = DynamicPPL.unflatten(vi, initial_params) @@ -371,10 +347,13 @@ function DynamicPPL.initialstep( for (varnames_local, sampler_local) in zip(varnames, samplers) varnames_local = _maybevec(varnames_local) # Get the initial values for this component sampler. - vi_local = DynamicPPL.subset(vi, varnames_local) - initial_params_local = initial_params === nothing ? nothing : vi_local[:] + initial_params_local = if initial_params === nothing + nothing + else + DynamicPPL.subset(vi, varnames_local)[:] + end - # Construct the conditional model. + # Construct the conditioned model. model_local, context_local = make_conditional(model, varnames_local, vi) # Take initial step. @@ -397,7 +376,7 @@ function DynamicPPL.initialstep( # This merges in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) - # This merges the latest values for all the variables in the current sampler. + # This merges the new values for all the variables sampled by the current sampler. vi = merge(vi, new_vi_local) push!(states, new_state_local) end @@ -473,8 +452,8 @@ function recompute_logprob!!( return setlogp!!(state, vi_new.logp[]) end -# TODO(mhauru) Would really like to type constrain this to something like AbstractMCMCState -# if such a thing existed. +# TODO(mhauru) Would really like to type constrain the first argument to something like +# AbstractMCMCState if such a thing existed. function DynamicPPL.setlogp!!(state, logp) try new_vi = setlogp!!(state.vi, logp) @@ -496,11 +475,11 @@ function DynamicPPL.setlogp!!(state::TuringState, logp) end # TODO(mhauru) In the general case, which arguments are really needed for reset_state!!? -# The current list is a guess, but I think some might be unnecessary. +# The current list is a guess, and I think some are unnecessary. """ reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous) -Return an updated state for a component sampler. +Return an updated state for a Gibbs component sampler. This takes into account changes caused by other Gibbs components. The default implementation is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to From b8c3dcdbac31651fd5a6bf9b43699d712f436cc4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 29 Oct 2024 10:29:36 +0000 Subject: [PATCH 19/70] Change how AdvancedHMC Gibbs state treats momenta --- src/mcmc/gibbs.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 9034e67080..ab3a827cdf 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -572,10 +572,18 @@ function reset_state!!( state.metric, DynamicPPL.LogDensityFunction(model) ) θ_new = varinfo[:] - # Set the momentum to zero, since we have no idea what it should be at the new parameter - # values. + # Modify the momentum to have the right number of elements, if the number of position + # variables has changed. Any new dimensions will be set to zero momentum. + # Note that there's no guarantee that any new variables are at the end of the parameter + # list, so we may end up mismatching momenta and parameters. This shouldn't be of + # consequence though, since the momentum will get resampled anyway. + # Frankly, we could probably just as well set the momenta to zero, but that made + # ForwardDiff crash for some reason I (mhauru) didn't bother to investigate. + momenta_old = state.transition.z.r + momenta_new = zero(θ_new) + momenta_new[1:length(momenta_old)] .= momenta_old return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, θ_new, zero(θ_new) + hamiltonian, θ_new, momenta_new ) end From da0b740e561f2a476b87fdc60b3454a6b2b16477 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 29 Oct 2024 11:31:03 +0000 Subject: [PATCH 20/70] Remove unnecessary invlinking --- src/mcmc/gibbs.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index ab3a827cdf..b2895d1e76 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -367,12 +367,6 @@ function DynamicPPL.initialstep( kwargs..., ) new_vi_local = varinfo(new_state_local) - # TODO(mhauru) Can we remove the invlinking? - new_vi_local = if DynamicPPL.istrans(new_vi_local) - DynamicPPL.invlink(new_vi_local, sampler_local, model_local) - else - new_vi_local - end # This merges in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) @@ -626,18 +620,9 @@ function gibbs_step_inner( state_local = states[index] varnames_local = _maybevec(varnames[index]) - # TODO(mhauru) Can we remove the invlinking? - vi = DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, vi) varinfo_local = subset(vi, varnames_local) - # TODO(mhauru) Can we remove the below, if get rid of all the invlinking? - # If the varinfo of the previous state from this sampler is linked, we should link the - # new varinfo too. - if DynamicPPL.istrans(varinfo(state_local)) - varinfo_local = DynamicPPL.link(varinfo_local, sampler_local, model_local) - end # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] From d984d2bbceec25e0e345c20c8eeb66fc96ebb8bc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 29 Oct 2024 13:50:43 +0000 Subject: [PATCH 21/70] Change how AdvancedHMC Gibbs state treats momenta, again --- src/mcmc/gibbs.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index b2895d1e76..8ee034e027 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -566,16 +566,13 @@ function reset_state!!( state.metric, DynamicPPL.LogDensityFunction(model) ) θ_new = varinfo[:] - # Modify the momentum to have the right number of elements, if the number of position - # variables has changed. Any new dimensions will be set to zero momentum. - # Note that there's no guarantee that any new variables are at the end of the parameter - # list, so we may end up mismatching momenta and parameters. This shouldn't be of - # consequence though, since the momentum will get resampled anyway. - # Frankly, we could probably just as well set the momenta to zero, but that made - # ForwardDiff crash for some reason I (mhauru) didn't bother to investigate. + # Set the momentum to some arbitrary value, making sure it has the right number of + # components. We could try to do something clever here to only reset momenta related to + # new variables, but it'll be resampled in the next iteration anyway. + # TODO(mhauru) Would prefer to set it to zeros rather than ones, but that makes + # ForwardDiff crash for some reason. Should investigate and report as a ForwardDiff bug. momenta_old = state.transition.z.r - momenta_new = zero(θ_new) - momenta_new[1:length(momenta_old)] .= momenta_old + momenta_new = ones(eltype(momenta_old), length(θ_new)) return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( hamiltonian, θ_new, momenta_new ) From d52af52251d59b9f2bd08f5aa97c2e0f278fb562 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 1 Nov 2024 18:54:40 +0000 Subject: [PATCH 22/70] Use setparams!! rather than reset_state!! --- src/mcmc/gibbs.jl | 205 +++++++++++++++------------------------------- src/mcmc/hmc.jl | 5 +- 2 files changed, 70 insertions(+), 140 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 8ee034e027..d1b88f9463 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -50,9 +50,16 @@ function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) ) end +get_global_varinfo(context::GibbsContext) = context.global_varinfo[] + +function set_global_varinfo!(context::GibbsContext, new_global_varinfo) + context.global_varinfo[] = new_global_varinfo + return nothing +end + # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.haskey(context.global_varinfo[], vn) + return DynamicPPL.haskey(get_global_varinfo(context), vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns)) @@ -66,7 +73,7 @@ function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return context.global_varinfo[][vn] + return get_global_varinfo(context)[vn] end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) @@ -110,9 +117,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.SampleFromPrior(), right, vn, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -137,9 +144,9 @@ function DynamicPPL.tilde_assume( DynamicPPL.SampleFromPrior(), right, vn, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -181,9 +188,9 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi right, left, vns, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -209,9 +216,9 @@ function DynamicPPL.dot_tilde_assume( right, left, vns, - context.global_varinfo[], + get_global_varinfo(context), ) - context.global_varinfo[] = new_global_vi + set_global_varinfo!(context, new_global_vi) return value, lp, vi end end @@ -468,139 +475,71 @@ function DynamicPPL.setlogp!!(state::TuringState, logp) return TuringState(setlogp!!(state.state, logp), logp) end -# TODO(mhauru) In the general case, which arguments are really needed for reset_state!!? -# The current list is a guess, and I think some are unnecessary. -""" - reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous) - -Return an updated state for a Gibbs component sampler. - -This takes into account changes caused by other Gibbs components. The default implementation -is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to -do, a method should be implemented for the specific type of `state`. +# Some samplers use a VarInfo directly as the state. In that case, there's little to do in +# `setparams!!`. +function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector) + return DynamicPPL.unflatten(state, params) +end -# Arguments -- `model::DynamicPPL.Model`: The model as seen by this component sampler. Variables not -sampled by this component sampler have been conditioned with a `GibbsContext`. -- `sampler::DynamicPPL.Sampler`: The current component sampler. -- `state`: The state of this component sampler from its previous iteration. -- `varinfo::DynamicPPL.AbstractVarInfo`: The current `VarInfo`, subsetted to the variables -sampled by this component sampler. -- `sampler_previous::DynamicPPL.Sampler`: The previous sampler in the Gibbs chain. -- `state_previous`: The state returned by the previous sampler. +function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo) + return params +end -# Returns -An updated state of the same type as `state`. It should have variables set to the values in -`varinfo`, and any other relevant updates done. -""" -function reset_state!!( - model, sampler, state, varinfo::AbstractVarInfo, sampler_previous, state_previous +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, + state::TuringState, + params::Union{AbstractVector,AbstractVarInfo}, ) - # In the fallback implementation we guess that `state` has a field called `vi` we can - # set. Fingers crossed! - try - return Accessors.set(state, Accessors.PropertyLens{:vi}(), varinfo) - catch - error( - "Unable to set `state.vi` for a $(typeof(state)). " * - "Consider writing a method for reset_state!! for this type.", - ) - end + new_inner_state = AbstractMCMC.setparams!!(model, state.state, params) + return TuringState(new_inner_state, state.logdensity) end -# Some samplers use a VarInfo directly as the state. In that case, there's little to do in -# `reset_state!!`. -function reset_state!!( - model, - sampler, - state::AbstractVarInfo, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - return varinfo +# Unless some other treatment has been specified for this state type, just flatten the +# AbstractVarInfo. This method exists because some sampler types need to override this +# behavior. +function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo) + return AbstractMCMC.setparams!!(model, state, params[:]) end -function reset_state!!( - model, - sampler, - state::TuringState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo ) - new_inner_state = reset_state!!( - model, sampler, state.state, varinfo, sampler_previous, state_previous + θ_new = params[:] + hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new)) + + # Update the parameter values in `state.z`. + # TODO: Avoid mutation + z = state.z + resize!(z.θ, length(θ_new)) + z.θ .= θ_new + return HMCState( + params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler ) - return TuringState(new_inner_state, state.logdensity) end -function reset_state!!( - model, - sampler, - state::HMCState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::HMCState, params::AbstractVector ) - θ_new = varinfo[:] - hamiltonian = get_hamiltonian(model, sampler, varinfo, state, length(θ_new)) + θ_new = params + vi = DynamicPPL.unflatten(state.vi, params) + hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new)) # Update the parameter values in `state.z`. # TODO: Avoid mutation z = state.z resize!(z.θ, length(θ_new)) z.θ .= θ_new - return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) + return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler) end -function reset_state!!( - model, - sampler, - state::AdvancedHMC.HMCState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo ) - hamiltonian = AdvancedHMC.Hamiltonian( - state.metric, DynamicPPL.LogDensityFunction(model) - ) - θ_new = varinfo[:] - # Set the momentum to some arbitrary value, making sure it has the right number of - # components. We could try to do something clever here to only reset momenta related to - # new variables, but it'll be resampled in the next iteration anyway. - # TODO(mhauru) Would prefer to set it to zeros rather than ones, but that makes - # ForwardDiff crash for some reason. Should investigate and report as a ForwardDiff bug. - momenta_old = state.transition.z.r - momenta_new = ones(eltype(momenta_old), length(θ_new)) - return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, θ_new, momenta_new - ) + return PGState(params, state.rng) end -function reset_state!!( - model, - sampler, - state::AdvancedMH.Transition, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - # TODO(mhauru) Setting the last argument like this seems a bit suspect, since the - # current values for the parameters might not have come from this sampler at all. - # I don't see a better way though. - return AdvancedMH.Transition(varinfo[:], varinfo.logp[], state.accepted) -end - -function reset_state!!( - model, - sampler, - state::PGState, - varinfo::AbstractVarInfo, - sampler_previous, - state_previous, -) - return PGState(varinfo, state.rng) +function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector) + return PGState(DynamicPPL.unflatten(state.vi, params), state.rng) end function gibbs_step_inner( @@ -609,7 +548,7 @@ function gibbs_step_inner( varnames, samplers, states, - vi, + global_vi, index; kwargs..., ) @@ -618,8 +557,8 @@ function gibbs_step_inner( varnames_local = _maybevec(varnames[index]) # Construct the conditional model and the varinfo that this sampler should use. - model_local, context_local = make_conditional(model, varnames_local, vi) - varinfo_local = subset(vi, varnames_local) + model_local, context_local = make_conditional(model, varnames_local, global_vi) + varinfo_local = subset(global_vi, varnames_local) # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] @@ -627,14 +566,7 @@ function gibbs_step_inner( # Set the state of the current sampler, accounting for any changes made by other # samplers. - state_local = reset_state!!( - model_local, - sampler_local, - state_local, - varinfo_local, - sampler_previous, - state_previous, - ) + state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local) if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous ) @@ -647,11 +579,8 @@ function gibbs_step_inner( ) new_vi_local = varinfo(new_state_local) - # This merges in any new variables that were introduced during the step, but that - # were not in the domain of the current sampler. - new_vi = merge(vi, context_local.global_varinfo[]) - # This merges the latest values for all the variables in the current sampler. - new_vi = merge(new_vi, new_vi_local) - new_vi = setlogp!!(new_vi, new_vi_local.logp[]) - return new_vi, new_state_local + # Merge the latest values for all the variables in the current sampler. + new_global_vi = merge(get_global_varinfo(context_local), new_vi_local) + new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local)) + return new_global_vi, new_state_local end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index d01ef274a7..ab018e787a 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -15,6 +15,7 @@ struct HMCState{ hamiltonian::THam z::PhType adaptor::TAdapt + sampler::Sampler{<:Hamiltonian} end ### @@ -229,7 +230,7 @@ function DynamicPPL.initialstep( end transition = Transition(model, vi, t) - state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) + state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl) return transition, state end @@ -275,7 +276,7 @@ function AbstractMCMC.step( # Compute next transition and state. transition = Transition(model, vi, t) - newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor) + newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl) return transition, newstate end From 508ac61ad52bfdb2bdd7d4621e2ae6e62cafe3e2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 4 Nov 2024 15:54:00 +0000 Subject: [PATCH 23/70] Don't overload setparams\!\! with VarInfo --- src/mcmc/gibbs.jl | 77 ++++++++++++++++++++--------------------------- src/mcmc/hmc.jl | 5 ++- 2 files changed, 34 insertions(+), 48 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index d1b88f9463..9e1a705438 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -475,73 +475,58 @@ function DynamicPPL.setlogp!!(state::TuringState, logp) return TuringState(setlogp!!(state.state, logp), logp) end -# Some samplers use a VarInfo directly as the state. In that case, there's little to do in -# `setparams!!`. -function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector) - return DynamicPPL.unflatten(state, params) -end - -function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo) - return params -end +""" + setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo) -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, - state::TuringState, - params::Union{AbstractVector,AbstractVarInfo}, -) - new_inner_state = AbstractMCMC.setparams!!(model, state.state, params) - return TuringState(new_inner_state, state.logdensity) -end +A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an +`AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to +`AbstractMCMC.setparams!!(model, state, params[:])`. -# Unless some other treatment has been specified for this state type, just flatten the -# AbstractVarInfo. This method exists because some sampler types need to override this -# behavior. -function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo) +`model` is typically a `DynamicPPL.Model`, but can also be e.g. an +`AbstractMCMC.LogDensityModel`. +""" +function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) return AbstractMCMC.setparams!!(model, state, params[:]) end -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo +# Some samplers use a VarInfo directly as the state. In that case, there's little to do in +# `setparams_varinfo!!`. +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::VarInfo, params::AbstractVarInfo ) - θ_new = params[:] - hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new)) + return params +end - # Update the parameter values in `state.z`. - # TODO: Avoid mutation - z = state.z - resize!(z.θ, length(θ_new)) - z.θ .= θ_new - return HMCState( - params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::TuringState, params::AbstractVarInfo +) + logdensity = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype) + new_inner_state = setparams_varinfo!!( + AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params ) + return TuringState(new_inner_state, logdensity) end -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, state::HMCState, params::AbstractVector +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::HMCState, params::AbstractVarInfo ) - θ_new = params - vi = DynamicPPL.unflatten(state.vi, params) - hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new)) + θ_new = params[:] + hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new)) # Update the parameter values in `state.z`. # TODO: Avoid mutation z = state.z resize!(z.θ, length(θ_new)) z.θ .= θ_new - return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler) + return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor) end -function AbstractMCMC.setparams!!( - model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler, state::PGState, params::AbstractVarInfo ) return PGState(params, state.rng) end -function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector) - return PGState(DynamicPPL.unflatten(state.vi, params), state.rng) -end - function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -566,7 +551,9 @@ function gibbs_step_inner( # Set the state of the current sampler, accounting for any changes made by other # samplers. - state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local) + state_local = setparams_varinfo!!( + model_local, sampler_local, state_local, varinfo_local + ) if gibbs_requires_recompute_logprob( model_local, sampler_local, sampler_previous, state_local, state_previous ) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index ab018e787a..d01ef274a7 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -15,7 +15,6 @@ struct HMCState{ hamiltonian::THam z::PhType adaptor::TAdapt - sampler::Sampler{<:Hamiltonian} end ### @@ -230,7 +229,7 @@ function DynamicPPL.initialstep( end transition = Transition(model, vi, t) - state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl) + state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) return transition, state end @@ -276,7 +275,7 @@ function AbstractMCMC.step( # Compute next transition and state. transition = Transition(model, vi, t) - newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl) + newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor) return transition, newstate end From 6ff7c59aae9ff321a0e31ba809b66ff7a7788df2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 5 Nov 2024 10:15:46 +0000 Subject: [PATCH 24/70] A fix for ESS in Gibbs --- src/mcmc/gibbs.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 9e1a705438..e13b71f8bc 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -374,6 +374,12 @@ function DynamicPPL.initialstep( kwargs..., ) new_vi_local = varinfo(new_state_local) + # TODO(mhauru) Remove the below loop once samplers no longer depend on selectors. + # For some reason not having this in place was causing trouble for ESS, but not for + # other samplers. I didn't get to the bottom of it. + for vn in keys(new_vi_local) + DynamicPPL.setgid!(new_vi_local, sampler_local.selector, vn) + end # This merges in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) @@ -544,6 +550,12 @@ function gibbs_step_inner( # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, global_vi) varinfo_local = subset(global_vi, varnames_local) + # TODO(mhauru) Remove the below loop once samplers no longer depend on selectors. + # For some reason not having this in place was causing trouble for ESS, but not for + # other samplers. I didn't get to the bottom of it. + for vn in keys(varinfo_local) + DynamicPPL.setgid!(varinfo_local, sampler_local.selector, vn) + end # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] From 934c03efd281c092c0c69e54b49325a3ba21f5b6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 5 Nov 2024 16:02:16 +0000 Subject: [PATCH 25/70] Remove recompute_logprob!! --- src/mcmc/abstractmcmc.jl | 45 ------------------ src/mcmc/gibbs.jl | 100 ++++++--------------------------------- test/mcmc/gibbs.jl | 4 +- 3 files changed, 18 insertions(+), 131 deletions(-) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index c6ca61a9c4..8dfee52b4a 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -56,51 +56,6 @@ function setvarinfo( ) end -""" - recompute_logprob!!(rng, model, sampler, state) - -Recompute the log-probability of the `model` based on the given `state` and return the resulting state. -""" -function recompute_logprob!!( - rng::Random.AbstractRNG, # TODO: Do we need the `rng` here? - model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:ExternalSampler}, - state, # TODO(mhauru) Could we type constrain this to TuringState? -) - # Re-using the log-density function from the `state` and updating only the `model` field, - # since the `model` might now contain different conditioning values. - f = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype) - # Recompute the log-probability with the new `model`. - state_inner = recompute_logprob!!( - rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state - ) - return state_to_turing(f, state_inner) -end - -function recompute_logprob!!( - rng::Random.AbstractRNG, - model::AbstractMCMC.LogDensityModel, - sampler::AdvancedHMC.AbstractHMCSampler, - state::AdvancedHMC.HMCState, -) - # Construct hamiltionian. - hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) - # Re-compute the log-probability and gradient. - return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, state.transition.z.θ, state.transition.z.r - ) -end - -function recompute_logprob!!( - rng::Random.AbstractRNG, - model::AbstractMCMC.LogDensityModel, - sampler::AdvancedMH.MetropolisHastings, - state::AdvancedMH.Transition, -) - logdensity = model.logdensity - return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params) -end - # TODO: Do we also support `resume`, etc? function AbstractMCMC.step( rng::Random.AbstractRNG, diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index e13b71f8bc..d44184da22 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -407,80 +407,17 @@ function AbstractMCMC.step( # TODO: move this into a recursive function so we can unroll when reasonable? for index in 1:length(samplers) # Take the inner step. + sampler_local = samplers[index] + state_local = states[index] + varnames_local = _maybevec(varnames[index]) vi, new_state_local = gibbs_step_inner( - rng, model, varnames, samplers, states, vi, index; kwargs... + rng, model, varnames_local, sampler_local, state_local, vi; kwargs... ) states = Accessors.setindex(states, new_state_local, index) end return Transition(model, vi), GibbsState(vi, states) end -# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. -function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) - # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide - # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact - # same `selector` as before but now with `rerun` set to `true` if needed. - return Accessors.@set sampler.selector.rerun = true -end - -# Interface we need a sampler to implement to work as a component in a Gibbs sampler. -""" - gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - -Check if the log-probability of the destination model needs to be recomputed. - -Defaults to `true` -""" -function gibbs_requires_recompute_logprob( - model_dst, sampler_dst, sampler_src, state_dst, state_src -) - return true -end - -# TODO: Remove `rng`? -function recompute_logprob!!( - rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, state -) - vi = varinfo(state) - # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, - # e.g. log-likelihood in the scenario of `ESS`. - # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. - sampler_rerun = make_rerun_sampler(model, sampler) - # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed - # `varinfo`, even if `varinfo` was linked. - vi_new = last( - DynamicPPL.evaluate!!( - model, - vi, - # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. - DynamicPPL.SamplingContext(rng, sampler_rerun), - ) - ) - return setlogp!!(state, vi_new.logp[]) -end - -# TODO(mhauru) Would really like to type constrain the first argument to something like -# AbstractMCMCState if such a thing existed. -function DynamicPPL.setlogp!!(state, logp) - try - new_vi = setlogp!!(state.vi, logp) - if new_vi !== state.vi - return Accessors.set(state, Accessors.PropertyLens{:vi}(), new_vi) - else - return state - end - catch - error( - "Unable to set `state.vi` for a $(typeof(state)). " * - "Consider writing a method for `setlogp!!` for this type.", - ) - end -end - -function DynamicPPL.setlogp!!(state::TuringState, logp) - return TuringState(setlogp!!(state.state, logp), logp) -end - """ setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo) @@ -536,17 +473,12 @@ end function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, - varnames, - samplers, - states, - global_vi, - index; + varnames_local, + sampler_local, + state_local, + global_vi; kwargs..., ) - sampler_local = samplers[index] - state_local = states[index] - varnames_local = _maybevec(varnames[index]) - # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, global_vi) varinfo_local = subset(global_vi, varnames_local) @@ -557,20 +489,18 @@ function gibbs_step_inner( DynamicPPL.setgid!(varinfo_local, sampler_local.selector, vn) end - # Extract the previous sampler and state. - sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] - state_previous = states[index == 1 ? length(states) : index - 1] - + # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not + # sampled by other samplers, we don't need to `setparams`, but could rather simply + # recompute the log probability. More over, in some cases the recomputation could also + # be avoided, if e.g. the previous sampler has done all the necessary work already. + # However, we've judged that doing any caching or other tricks to avoid this now would + # be premature optimization. In most use cases of Gibbs a single model call here is not + # going to be a significant expense anyway. # Set the state of the current sampler, accounting for any changes made by other # samplers. state_local = setparams_varinfo!!( model_local, sampler_local, state_local, varinfo_local ) - if gibbs_requires_recompute_logprob( - model_local, sampler_local, sampler_previous, state_local, state_previous - ) - state_local = recompute_logprob!!(rng, model_local, sampler_local, state_local) - end # Take a step with the local sampler. new_state_local = last( diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index c11f291628..65f192dc71 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -183,6 +183,8 @@ has_dot_assume(::DynamicPPL.Model) = true end @testset "dynamic model" begin + # TODO(mhauru) We should check that the results of the sampling are correct. + # Currently we just check that this doesn't crash. @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} N = length(y) rpm = DirichletProcess(alpha) @@ -204,7 +206,7 @@ has_dot_assume(::DynamicPPL.Model) = true end model = imm(Random.randn(100), 1.0) # https://github.com/TuringLang/Turing.jl/issues/1725 - # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100); + # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) end From 93010f76f9f1be36318fa07ad191f46539cc4618 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 10:17:36 +0000 Subject: [PATCH 26/70] Fix setparams_varinfo!! for MH --- src/mcmc/gibbs.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index d44184da22..ad416d6891 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -432,12 +432,12 @@ function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) return AbstractMCMC.setparams!!(model, state, params[:]) end -# Some samplers use a VarInfo directly as the state. In that case, there's little to do in -# `setparams_varinfo!!`. function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Sampler, state::VarInfo, params::AbstractVarInfo + model::DynamicPPL.Model, sampler::Sampler{<:MH}, state::VarInfo, params::AbstractVarInfo ) - return params + # The state is already a VarInfo, so we can just return `params`, but first we need to + # update its logprob. + return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.DefaultContext())) end function setparams_varinfo!!( From 1576d21e98c12286a4a84c74bc5f22836a703be9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 10:18:58 +0000 Subject: [PATCH 27/70] Stop hard coding the leafcontext for MH setparams_varinfo!! --- src/mcmc/gibbs.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index ad416d6891..48794a905e 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -437,7 +437,8 @@ function setparams_varinfo!!( ) # The state is already a VarInfo, so we can just return `params`, but first we need to # update its logprob. - return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.DefaultContext())) + # TODO(mhauru) Is this the right context to use? + return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) end function setparams_varinfo!!( From 96674f8ce40e2f96addaaa24f1d40544d5dfa5ff Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 10:50:41 +0000 Subject: [PATCH 28/70] Fix setparams_varinfo!! for ESS --- src/mcmc/gibbs.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 48794a905e..d7d3020098 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -441,6 +441,18 @@ function setparams_varinfo!!( return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) end +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::Sampler{<:ESS}, + state::VarInfo, + params::AbstractVarInfo, +) + # The state is already a VarInfo, so we can just return `params`, but first we need to + # update its logprob. + # TODO(mhauru) Is this the right context to use? + return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) +end + function setparams_varinfo!!( model::DynamicPPL.Model, sampler::Sampler, state::TuringState, params::AbstractVarInfo ) From 69f14ceaeacd6944f64768c2b3346c8941baef74 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 11:19:48 +0000 Subject: [PATCH 29/70] Fix the context used by setparams_varinfo!! ESS --- src/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index d7d3020098..8e44a4a8de 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -450,7 +450,7 @@ function setparams_varinfo!!( # The state is already a VarInfo, so we can just return `params`, but first we need to # update its logprob. # TODO(mhauru) Is this the right context to use? - return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) + return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.LikelihoodContext())) end function setparams_varinfo!!( From 00bcdd830e5d42f9d0c693c1da2058c0e05994cf Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 10:54:14 +0000 Subject: [PATCH 30/70] Add GibbsContext type stability tests --- test/Project.toml | 1 + test/mcmc/gibbs.jl | 79 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index ce01d6a210..9ccbe5d770 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 65f192dc71..5f8ef36371 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -8,6 +8,7 @@ using ..NumericalTests: check_numerical, two_sample_test import ..ADUtils +import Combinatorics using Distributions: InverseGamma, Normal using Distributions: sample using DynamicPPL: DynamicPPL @@ -15,7 +16,7 @@ using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff import Mooncake -using Test: @test, @test_broken, @test_deprecated, @testset +using Test: @inferred, @test, @test_broken, @test_deprecated, @testset using Turing using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH @@ -42,6 +43,82 @@ const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false has_dot_assume(::DynamicPPL.Model) = true +@testset "GibbsContext" begin + @testset "type stability" begin + # A test model that has multiple features in one package: + # Floats, Ints, arguments, observations, loops, dot_tildes. + @model function test_model(obs1, obs2, num_vars, mean) + variance ~ Exponential(2) + z = Vector{Float64}(undef, num_vars) + z .~ truncated(Normal(mean, variance); lower=1) + y = Vector{Int64}(undef, num_vars) + for i in 1:num_vars + y[i] ~ Poisson(Int(round(z[i]))) + end + s = sum(y) - sum(z) + obs1 ~ Normal(s, 1) + return obs2 ~ Poisson(y[3]) + end + + model = test_model(1.2, 2, 10, 2.5) + all_varnames = DynamicPPL.VarName[@varname(variance), @varname(z), @varname(y)] + # All combinations of elements in all_varnames. + target_vn_combinations = Iterators.flatten( + Iterators.map( + n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) + ), + ) + for typed in (true, false) + for target_vns in target_vn_combinations + global_varinfo = + typed ? DynamicPPL.VarInfo(model) : DynamicPPL.untyped_varinfo(model) + target_vns = collect(target_vns) + local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) + ctx = Turing.Inference.GibbsContext( + target_vns, Ref(global_varinfo), Turing.DefaultContext() + ) + + # Check that the correct varnames are conditioned, and that getting their + # values is type stable. + for k in keys(global_varinfo) + is_target = any( + Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns) + ) + @test Turing.Inference.is_target_varname(ctx, k) == is_target + if !is_target && typed + @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) + end + end + + # Check the type stability also in the dot_tilde pipeline. + for k in all_varnames + # The map(identity, ...) part is there to concretise the eltype. + subkeys = map( + identity, + filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)), + ) + is_target = (k in target_vns) + @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target + if !is_target && typed + @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) + end + end + + # Check that evaluate!! and the result it returns are type stable. + conditioned_model = DynamicPPL.contextualize(model, ctx) + _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( + conditioned_model, local_varinfo + ) + if typed + for k in keys(post_eval_varinfo) + @inferred post_eval_varinfo[k] + end + end + end + end + end +end + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "Deprecated Gibbs constructors" begin N = 10 From dc5b0cfd566f613e2e1bbb1954115053b9caa6d4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 14:14:37 +0000 Subject: [PATCH 31/70] Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde --- src/mcmc/gibbs.jl | 12 ++++++++---- test/mcmc/gibbs.jl | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 8e44a4a8de..23cdd9b190 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -16,11 +16,14 @@ A context used in the implementation of the Turing.jl Gibbs sampler. There will be one `GibbsContext` for each iteration of a component sampler. + +# Fields +$(DocStringExtensions.FIELDS) """ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext """ - a collection of `VarName`s that are the ones the current component sampler is sampling. + a collection of `VarName`s that the current component sampler is sampling. For them, `GibbsContext` will just pass tilde_assume calls to its child context. For other variables, their values will be fixed to the values they have in `global_varinfo`. @@ -380,10 +383,10 @@ function DynamicPPL.initialstep( for vn in keys(new_vi_local) DynamicPPL.setgid!(new_vi_local, sampler_local.selector, vn) end - # This merges in any new variables that were introduced during the step, but that + # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) - # This merges the new values for all the variables sampled by the current sampler. + # Merge the new values for all the variables sampled by the current sampler. vi = merge(vi, new_vi_local) push!(states, new_state_local) end @@ -437,7 +440,8 @@ function setparams_varinfo!!( ) # The state is already a VarInfo, so we can just return `params`, but first we need to # update its logprob. - # TODO(mhauru) Is this the right context to use? + # NOTE: Using `leafcontext(model.context)` here is a no-op, as it will be concatenated + # with `model.context` before hitting `model.f`. return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 5f8ef36371..fecb9adec6 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -79,7 +79,7 @@ has_dot_assume(::DynamicPPL.Model) = true ) # Check that the correct varnames are conditioned, and that getting their - # values is type stable. + # values is type stable when the varinfo is. for k in keys(global_varinfo) is_target = any( Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns) From fa366d3b24f6820e9e0f2a08bc87e7e593b44642 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 14:32:20 +0000 Subject: [PATCH 32/70] Add clarifying comment --- src/mcmc/gibbs.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 23cdd9b190..b0d9e8a8c3 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -453,7 +453,8 @@ function setparams_varinfo!!( ) # The state is already a VarInfo, so we can just return `params`, but first we need to # update its logprob. - # TODO(mhauru) Is this the right context to use? + # Note the use of LikelihoodContext, regardless of what context `model` has. This is + # specific to ESS as a sampler. return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.LikelihoodContext())) end From 4e998704f6892aefe94337eac5cb20f4055de1a6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 14:32:42 +0000 Subject: [PATCH 33/70] Add setparams_varinfo!! type bounds --- src/mcmc/gibbs.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index b0d9e8a8c3..fb4f4d5e76 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -459,7 +459,10 @@ function setparams_varinfo!!( end function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Sampler, state::TuringState, params::AbstractVarInfo + model::DynamicPPL.Model, + sampler::Sampler{<:ExternalSampler}, + state::TuringState, + params::AbstractVarInfo, ) logdensity = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype) new_inner_state = setparams_varinfo!!( @@ -469,7 +472,10 @@ function setparams_varinfo!!( end function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Sampler, state::HMCState, params::AbstractVarInfo + model::DynamicPPL.Model, + sampler::Sampler{<:Hamiltonian}, + state::HMCState, + params::AbstractVarInfo, ) θ_new = params[:] hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new)) @@ -483,7 +489,7 @@ function setparams_varinfo!!( end function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Sampler, state::PGState, params::AbstractVarInfo + model::DynamicPPL.Model, sampler::Sampler{<:PG}, state::PGState, params::AbstractVarInfo ) return PGState(params, state.rng) end From da7342e67d36920490764a591e74f655521eda20 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 14:42:48 +0000 Subject: [PATCH 34/70] Fix an import --- src/mcmc/Inference.jl | 2 +- src/mcmc/gibbs.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index ca1f45e102..cd1b78be01 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -35,7 +35,7 @@ using StatsFuns: logsumexp using Random: AbstractRNG using DynamicPPL using AbstractMCMC: AbstractModel, AbstractSampler -using DocStringExtensions: TYPEDEF, TYPEDFIELDS +using DocStringExtensions: FIELDS, TYPEDEF, TYPEDFIELDS using DataStructures: OrderedSet using Accessors: Accessors diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index fb4f4d5e76..310650159c 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -18,7 +18,7 @@ A context used in the implementation of the Turing.jl Gibbs sampler. There will be one `GibbsContext` for each iteration of a component sampler. # Fields -$(DocStringExtensions.FIELDS) +$(FIELDS) """ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext From 14decaf063835fd1f515380eba83bce0e7db0b0a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 14:45:39 +0000 Subject: [PATCH 35/70] Style improvement --- src/mcmc/gibbs.jl | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 310650159c..a9817e25b1 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -101,15 +101,15 @@ end # Tilde pipeline function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - if is_target_varname(context, vn) + return if is_target_varname(context, vn) # Fall back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) + DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) elseif has_conditioned_gibbs(context, vn) # Short-circuit the tilde assume if `vn` is present in `context`. value = get_conditioned_gibbs(context, vn) # TODO(mhauru) Is the call to logpdf correct if context.context is not # DefaultContext? - return value, logpdf(right, value), vi + value, logpdf(right, value), vi else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add @@ -123,7 +123,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - return value, lp, vi + value, lp, vi end end @@ -132,14 +132,14 @@ function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi ) # See comment in the above, rng-less version of this method for an explanation. - if is_target_varname(context, vn) - return DynamicPPL.tilde_assume( + return if is_target_varname(context, vn) + DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) elseif has_conditioned_gibbs(context, vn) value = get_conditioned_gibbs(context, vn) # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? - return value, logpdf(right, value), vi + value, logpdf(right, value), vi else value, lp, new_global_vi = DynamicPPL.tilde_assume( rng, @@ -150,7 +150,7 @@ function DynamicPPL.tilde_assume( get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - return value, lp, vi + value, lp, vi end end @@ -175,14 +175,12 @@ end # Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf. # See comments there for more details. function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - if is_target_varname(context, vns) - return DynamicPPL.dot_tilde_assume( - DynamicPPL.childcontext(context), right, left, vns, vi - ) + return if is_target_varname(context, vns) + DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) elseif has_conditioned_gibbs(context, vns) value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? - return value, broadcast_logpdf(right, value), vi + value, broadcast_logpdf(right, value), vi else prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( @@ -194,7 +192,7 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - return value, lp, vi + value, lp, vi end end @@ -202,14 +200,14 @@ end function DynamicPPL.dot_tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi ) - if is_target_varname(context, vns) - return DynamicPPL.dot_tilde_assume( + return if is_target_varname(context, vns) + DynamicPPL.dot_tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi ) elseif has_conditioned_gibbs(context, vns) value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? - return value, broadcast_logpdf(right, value), vi + value, broadcast_logpdf(right, value), vi else prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( @@ -222,7 +220,7 @@ function DynamicPPL.dot_tilde_assume( get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - return value, lp, vi + value, lp, vi end end From a7317e8380c2efc7e0604b6460c21e127e3d5f81 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 15:24:05 +0000 Subject: [PATCH 36/70] Improve GibbsContext type stability test --- test/mcmc/gibbs.jl | 77 +++++++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 42 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index fecb9adec6..0ee75efb65 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -57,7 +57,8 @@ has_dot_assume(::DynamicPPL.Model) = true end s = sum(y) - sum(z) obs1 ~ Normal(s, 1) - return obs2 ~ Poisson(y[3]) + obs2 ~ Poisson(y[3]) + return obs1, obs2, variance, z, y, s end model = test_model(1.2, 2, 10, 2.5) @@ -68,53 +69,45 @@ has_dot_assume(::DynamicPPL.Model) = true n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) ), ) - for typed in (true, false) - for target_vns in target_vn_combinations - global_varinfo = - typed ? DynamicPPL.VarInfo(model) : DynamicPPL.untyped_varinfo(model) - target_vns = collect(target_vns) - local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) - ctx = Turing.Inference.GibbsContext( - target_vns, Ref(global_varinfo), Turing.DefaultContext() - ) - - # Check that the correct varnames are conditioned, and that getting their - # values is type stable when the varinfo is. - for k in keys(global_varinfo) - is_target = any( - Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns) - ) - @test Turing.Inference.is_target_varname(ctx, k) == is_target - if !is_target && typed - @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) - end - end + for target_vns in target_vn_combinations + global_varinfo = DynamicPPL.VarInfo(model) + target_vns = collect(target_vns) + local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) + ctx = Turing.Inference.GibbsContext( + target_vns, Ref(global_varinfo), Turing.DefaultContext() + ) - # Check the type stability also in the dot_tilde pipeline. - for k in all_varnames - # The map(identity, ...) part is there to concretise the eltype. - subkeys = map( - identity, - filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)), - ) - is_target = (k in target_vns) - @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target - if !is_target && typed - @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) - end + # Check that the correct varnames are conditioned, and that getting their + # values is type stable when the varinfo is. + for k in keys(global_varinfo) + is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)) + @test Turing.Inference.is_target_varname(ctx, k) == is_target + if !is_target + @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) end + end - # Check that evaluate!! and the result it returns are type stable. - conditioned_model = DynamicPPL.contextualize(model, ctx) - _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( - conditioned_model, local_varinfo + # Check the type stability also in the dot_tilde pipeline. + for k in all_varnames + # The map(identity, ...) part is there to concretise the eltype. + subkeys = map( + identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)) ) - if typed - for k in keys(post_eval_varinfo) - @inferred post_eval_varinfo[k] - end + is_target = (k in target_vns) + @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target + if !is_target + @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) end end + + # Check that evaluate!! and the result it returns are type stable. + conditioned_model = DynamicPPL.contextualize(model, ctx) + _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( + conditioned_model, local_varinfo + ) + for k in keys(post_eval_varinfo) + @inferred post_eval_varinfo[k] + end end end end From 7e75caa0952a499f5c24c5b78b41ef13d45ebb89 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 15:57:32 +0000 Subject: [PATCH 37/70] Add comment to constructor tests --- test/mcmc/gibbs.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 0ee75efb65..b61347ed49 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -140,14 +140,22 @@ end end @testset "Gibbs constructors" begin + # Create Gibbs samplers with various configurations and ways of passing the + # arguments, and run them all on the `gdemo_default` model, see that nothing breaks. N = 10 - s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5, :s, :m; adtype=adbackend)) + # Two variables being sampled by one sampler. + s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) + # One variable per sampler, using the keyword arg interface. s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) + # As above but using a Dict of VarNames. s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) + # As above but different samplers. s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) + # Multiple instnaces of the same sampler. This implements running, in this case, + # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. s8 = begin hmc = HMC(0.1, 5; adtype=adbackend) pg = PG(10) From 16ba3f82d28c0553958dfc4aa44341f464284a3c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 16:02:59 +0000 Subject: [PATCH 38/70] Fix a Gibbs test --- test/mcmc/gibbs.jl | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index b61347ed49..e416c2e12f 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -289,22 +289,20 @@ end end @testset "dynamic model with dot tilde" begin - @model function dynamic_model_with_dot_tilde(num_zs=10) - z = Vector(undef, num_zs) - z .~ Exponential(1.0) - num_ms = Int(round(sum(z))) - m = Vector(undef, num_ms) + @model function dynamic_model_with_dot_tilde( + num_zs=10, ::Type{M}=Vector{Float64} + ) where {M} + z = M(undef, num_zs) + z .~ Poisson(1.0) + num_ms = sum(z) + m = M(undef, num_ms) return m .~ Normal(1.0, 1.0) end model = dynamic_model_with_dot_tilde() # TODO(mhauru) This is broken because of # https://github.com/TuringLang/DynamicPPL.jl/issues/700. @test_broken ( - sample( - model, - Gibbs(; z=NUTS(; adtype=adbackend), m=HMC(0.01, 4; adtype=adbackend)), - 100, - ); + sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); true ) end From d35cbce40fe57385ab8af4aaeeac79e43d7651b3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 12 Nov 2024 16:09:53 +0000 Subject: [PATCH 39/70] Document the methods of `varinfo` better --- src/mcmc/abstractmcmc.jl | 4 ++-- src/mcmc/gibbs.jl | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index 8dfee52b4a..a815a1bc16 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -28,8 +28,8 @@ function varinfo(state::TuringState) return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ) end varinfo(state::AbstractVarInfo) = state -# TODO(mhauru) Could we have a type bound on the argument below, for documentation purposes? -varinfo(state) = state.vi +varinfo(state::HMCState) = state.vi +varinfo(state::PGState) = state.vi # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index a9817e25b1..256aa315a0 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -332,6 +332,8 @@ _maybevec(x) = vec(x) # assume it's iterable _maybevec(x::Tuple) = [x...] _maybevec(x::VarName) = [x] +varinfo(state::GibbsState) = state.vi + function DynamicPPL.initialstep( rng::Random.AbstractRNG, model::DynamicPPL.Model, From 5fa94584d502997a05cd50700a2326a2778d8419 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 13 Nov 2024 11:18:29 +0000 Subject: [PATCH 40/70] Check whether a sampler is a valid Gibbs component --- src/mcmc/gibbs.jl | 33 +++++++++++++++++++++++++++++++++ test/mcmc/gibbs.jl | 19 ++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 256aa315a0..cd69cc7f17 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -1,3 +1,23 @@ +""" + isgibbscomponent(alg::Union{InferenceAlgorithm, AbstractMCMC.AbstractSampler}) + +Return a boolean indicating whether `alg` is a valid component for a Gibbs sampler. + +Defaults to `false` if no method has been defined for a particular algorithm type. +""" +isgibbscomponent(::InferenceAlgorithm) = false +isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler) +isgibbscomponent(spl::Sampler) = isgibbscomponent(spl.alg) + +isgibbscomponent(::ESS) = true +isgibbscomponent(::HMC) = true +isgibbscomponent(::HMCDA) = true +isgibbscomponent(::NUTS) = true +isgibbscomponent(::MH) = true +isgibbscomponent(::PG) = true +isgibbscomponent(::AdvancedHMC.HMC) = true +isgibbscomponent(::AdvancedMH.MetropolisHastings) = true + # Basically like a `DynamicPPL.FixedContext` but # 1. Hijacks the tilde pipeline to fix variables. # 2. Computes the log-probability of the fixed variables. @@ -267,6 +287,19 @@ struct Gibbs{V,A} <: InferenceAlgorithm varnames::V "samplers for each entry in `varnames`" samplers::A + + function Gibbs(varnames, samplers) + if length(varnames) != length(samplers) + throw(ArgumentError("Number of varnames and samplers must match.")) + end + for spl in samplers + if !isgibbscomponent(spl) + msg = "All samplers must be valid Gibbs components, $(spl) is not." + throw(ArgumentError(msg)) + end + end + return new{typeof(varnames),typeof(samplers)}(varnames, samplers) + end end # NamedTuple diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index e416c2e12f..c8441c0517 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -16,7 +16,7 @@ using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff import Mooncake -using Test: @inferred, @test, @test_broken, @test_deprecated, @testset +using Test: @inferred, @test, @test_broken, @test_deprecated, @test_throws, @testset using Turing using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH @@ -112,6 +112,23 @@ has_dot_assume(::DynamicPPL.Model) = true end end +@testset "Invalid Gibbs constructor" begin + # More samplers than varnames or vice versa + @test_throws ArgumentError Gibbs((@varname(s), @varname(m)), (NUTS(), NUTS(), NUTS())) + @test_throws ArgumentError Gibbs( + (@varname(s), @varname(m), @varname(x)), (NUTS(), NUTS()) + ) + # Invalid samplers + @test_throws ArgumentError Gibbs(@varname(s) => IS()) + @test_throws ArgumentError Gibbs(@varname(s) => Emcee(10, 2.0)) + @test_throws ArgumentError Gibbs( + @varname(s) => SGHMC(; learning_rate=0.01, momentum_decay=0.1) + ) + @test_throws ArgumentError Gibbs( + @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) + ) +end + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "Deprecated Gibbs constructors" begin N = 10 From 80cc62f5d6a7118bf06c97df2951c1c9df9c9e59 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 13 Nov 2024 14:23:09 +0000 Subject: [PATCH 41/70] Move varinfo methods where they belong --- src/mcmc/Inference.jl | 13 +++++++++++++ src/mcmc/abstractmcmc.jl | 2 -- src/mcmc/emcee.jl | 2 ++ src/mcmc/ess.jl | 2 ++ src/mcmc/gibbs.jl | 21 +++++---------------- src/mcmc/hmc.jl | 16 ++++++++++++++++ src/mcmc/is.jl | 2 ++ src/mcmc/mh.jl | 2 ++ src/mcmc/particle_mcmc.jl | 6 ++++++ src/mcmc/sghmc.jl | 8 ++++++++ 10 files changed, 56 insertions(+), 18 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index cd1b78be01..6765e7b72c 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -92,6 +92,14 @@ abstract type Hamiltonian <: InferenceAlgorithm end abstract type StaticHamiltonian <: Hamiltonian end abstract type AdaptiveHamiltonian <: Hamiltonian end +# TODO(mhauru) Remove the below function once all the space/Selector stuff has been removed. +""" + drop_space(alg::InferenceAlgorithm) + +Return an `InferenceAlgorithm` like `alg`, but with all space information removed. +""" +function drop_space end + """ ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} @@ -133,6 +141,9 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain end end +# External samplers don't have notion of space to begin with. +drop_space(x::ExternalSampler) = x + DynamicPPL.getspace(::ExternalSampler) = () """ @@ -201,6 +212,8 @@ Algorithm for sampling from the prior. """ struct Prior <: InferenceAlgorithm end +drop_space(x::Prior) = x + function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index a815a1bc16..aec7b153a9 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -28,8 +28,6 @@ function varinfo(state::TuringState) return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ) end varinfo(state::AbstractVarInfo) = state -varinfo(state::HMCState) = state.vi -varinfo(state::PGState) = state.vi # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index ebdfa041d7..816d90578a 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -26,6 +26,8 @@ function Emcee(n_walkers::Int, stretch_length=2.0) return Emcee{(),typeof(ensemble)}(ensemble) end +drop_space(alg::Emcee{space,E}) where {space,E} = Emcee{(),E}(alg.ensemble) + struct EmceeState{V<:AbstractVarInfo,S} vi::V states::S diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 395456ee5b..aa1a9fe380 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -25,6 +25,8 @@ struct ESS{space} <: InferenceAlgorithm end ESS() = ESS{()}() ESS(space::Symbol) = ESS{(space,)}() +drop_space(alg::ESS) = ESS() + # always accept in the first step function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index cd69cc7f17..4e9c7ac290 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -306,16 +306,17 @@ end Gibbs(; algs...) = Gibbs(NamedTuple(algs)) function Gibbs(algs::NamedTuple) return Gibbs( - map(s -> VarName{s}(), keys(algs)), map(wrap_algorithm_maybe, values(algs)) + map(s -> VarName{s}(), keys(algs)), + map(wrap_algorithm_maybe ∘ drop_space, values(algs)), ) end # AbstractDict function Gibbs(algs::AbstractDict) - return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) + return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe ∘ drop_space, values(algs))) end function Gibbs(algs::Pair...) - return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) + return Gibbs(map(first, algs), map(wrap_algorithm_maybe ∘ drop_space, map(last, algs))) end # The below two constructors only provide backwards compatibility with the constructor of @@ -339,7 +340,7 @@ function Gibbs(algs::InferenceAlgorithm...) "`Gibbs(@varname(x) => NUTS(), @varname(x) => NUTS(), @varname(y) => MH())`" ) Base.depwarn(msg, :Gibbs) - return Gibbs(varnames, map(wrap_algorithm_maybe, algs)) + return Gibbs(varnames, map(wrap_algorithm_maybe ∘ drop_space, algs)) end function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) @@ -410,12 +411,6 @@ function DynamicPPL.initialstep( kwargs..., ) new_vi_local = varinfo(new_state_local) - # TODO(mhauru) Remove the below loop once samplers no longer depend on selectors. - # For some reason not having this in place was causing trouble for ESS, but not for - # other samplers. I didn't get to the bottom of it. - for vn in keys(new_vi_local) - DynamicPPL.setgid!(new_vi_local, sampler_local.selector, vn) - end # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, context_local.global_varinfo[]) @@ -539,12 +534,6 @@ function gibbs_step_inner( # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, global_vi) varinfo_local = subset(global_vi, varnames_local) - # TODO(mhauru) Remove the below loop once samplers no longer depend on selectors. - # For some reason not having this in place was causing trouble for ESS, but not for - # other samplers. I didn't get to the bottom of it. - for vn in keys(varinfo_local) - DynamicPPL.setgid!(varinfo_local, sampler_local.selector, vn) - end # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not # sampled by other samplers, we don't need to `setparams`, but could rather simply diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 5887feb5e6..7770b822f3 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -21,6 +21,8 @@ end ### Hamiltonian Monte Carlo samplers. ### +varinfo(state::HMCState) = state.vi + """ HMC(ϵ::Float64, n_leapfrog::Int; adtype::ADTypes.AbstractADType = AutoForwardDiff()) @@ -76,6 +78,10 @@ function HMC( return HMC(ϵ, n_leapfrog, metricT, space; adtype=adtype) end +function drop_space(alg::HMC{AD,space,metricT}) where {AD,space,metricT} + return HMC{AD,(),metricT}(alg.ϵ, alg.n_leapfrog, alg.adtype) +end + DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform() # Handle setting `nadapts` and `discard_initial` @@ -376,6 +382,10 @@ function HMCDA( return HMCDA(n_adapts, δ, λ, init_ϵ, metricT, space; adtype=adtype) end +function drop_space(alg::HMCDA{AD,space,metricT}) where {AD,space,metricT} + return HMCDA{AD,(),metricT}(alg.n_adapts, alg.δ, alg.λ, alg.ϵ, alg.adtype) +end + """ NUTS(n_adapts::Int, δ::Float64; max_depth::Int=10, Δ_max::Float64=1000.0, init_ϵ::Float64=0.0; adtype::ADTypes.AbstractADType=AutoForwardDiff() @@ -453,6 +463,12 @@ function NUTS(; kwargs...) return NUTS(-1, 0.65; kwargs...) end +function drop_space(alg::NUTS{AD,space,metricT}) where {AD,space,metricT} + return NUTS{AD,(),metricT}( + alg.n_adapts, alg.δ, alg.max_depth, alg.Δ_max, alg.ϵ, alg.adtype + ) +end + for alg in (:HMC, :HMCDA, :NUTS) @eval getmetricT(::$alg{<:Any,<:Any,metricT}) where {metricT} = metricT end diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 0fe6e10535..083bc7bc3a 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -28,6 +28,8 @@ struct IS{space} <: InferenceAlgorithm end IS() = IS{()}() +drop_space(alg::IS) = IS() + DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler function DynamicPPL.initialstep( diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 8028588c69..8a9c19a4e8 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -145,6 +145,8 @@ function MH(space...) return MH{tuple(syms...),typeof(proposals)}(proposals) end +drop_space(alg::MH{space,P}) where {space,P} = MH{(),P}(alg.proposals) + # Some of the proposals require working in unconstrained space. transform_maybe(proposal::AMH.Proposal) = proposal function transform_maybe(proposal::AMH.RandomWalkProposal) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index a2b6757204..c5abb56f1d 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -45,6 +45,8 @@ end SMC(space::Symbol...) = SMC(space) SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space) +drop_space(alg::SMC{space,R}) where {space,R} = SMC{(),R}(alg.resampler) + struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T @@ -220,6 +222,8 @@ function PG(nparticles::Int, space::Tuple) return PG(nparticles, AdvancedPS.ResampleWithESSThreshold(), space) end +drop_space(alg::PG{space,R}) where {space,R} = PG{(),R}(alg.nparticles, alg.resampler) + """ CSMC(...) @@ -241,6 +245,8 @@ struct PGState rng::Random.AbstractRNG end +varinfo(state::PGState) = state.vi + function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) theta = getparams(model, vi) diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index fbc4b11868..c79337c500 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -49,6 +49,10 @@ function SGHMC( ) end +function drop_space(alg::SGHMC{AD,space,T}) where {AD,space,T} + return SGHMC{AD,(),T}(alg.learning_rate, alg.momentum_decay, alg.adtype) +end + struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} logdensity::L vi::V @@ -130,6 +134,10 @@ struct SGLD{AD,space,S} <: StaticHamiltonian adtype::AD end +function drop_space(alg::SGLD{AD,space,S}) where {AD,space,S} + return SGLD{AD,(),S}(alg.stepsize, alg.adtype) +end + struct PolynomialStepsize{T<:Real} "Constant scale factor of the step size." a::T From 053ecfc52edd19f6b7f01982811ae5cd002e73bb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 13 Nov 2024 16:43:36 +0000 Subject: [PATCH 42/70] Fix calling of child context in GibbsContext --- src/mcmc/gibbs.jl | 51 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 4e9c7ac290..fd32e9b541 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -126,10 +126,17 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) elseif has_conditioned_gibbs(context, vn) # Short-circuit the tilde assume if `vn` is present in `context`. - value = get_conditioned_gibbs(context, vn) - # TODO(mhauru) Is the call to logpdf correct if context.context is not - # DefaultContext? - value, logpdf(right, value), vi + child = DynamicPPL.childcontext(context) + if child isa SamplingContext + # TODO(mhauru) Would it ever be valid to have a SamplingContext as the child? + # We could just raise a warning, or optionally go down the stack of contexts + # skipping all SamplingContexts. The erroring is being conservative. + error("GibbsContext has a SamplingContext as its child.") + end + value, lp, _ = DynamicPPL.tilde_assume( + child, right, vn, get_global_varinfo(context) + ) + value, lp, vi else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add @@ -157,9 +164,15 @@ function DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) elseif has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? - value, logpdf(right, value), vi + child = DynamicPPL.childcontext(context) + if child isa SamplingContext + # TODO(mhauru) See comment in the method above. + error("GibbsContext has a SamplingContext as its child.") + end + value, lp, _ = DynamicPPL.tilde_assume( + child, right, vn, get_global_varinfo(context) + ) + value, lp, vi else value, lp, new_global_vi = DynamicPPL.tilde_assume( rng, @@ -198,9 +211,15 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi return if is_target_varname(context, vns) DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) elseif has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? - value, broadcast_logpdf(right, value), vi + child = DynamicPPL.childcontext(context) + if child isa SamplingContext + # TODO(mhauru) See comment in the method above. + error("GibbsContext has a SamplingContext as its child.") + end + value, lp, _ = DynamicPPL.dot_tilde_assume( + child, right, left, vns, get_global_varinfo(context) + ) + value, lp, vi else prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( @@ -225,9 +244,15 @@ function DynamicPPL.dot_tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi ) elseif has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - # TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext? - value, broadcast_logpdf(right, value), vi + child = DynamicPPL.childcontext(context) + if child isa SamplingContext + # TODO(mhauru) See comment in the method above. + error("GibbsContext has a SamplingContext as its child.") + end + value, lp, _ = DynamicPPL.dot_tilde_assume( + child, right, left, vns, get_global_varinfo(context) + ) + value, lp, vi else prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( From d53350d767732154ec31aee2ef649b29c12a89d3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 15 Nov 2024 18:57:10 +0000 Subject: [PATCH 43/70] Fix Selectors and type stability of Gibbs --- src/mcmc/gibbs.jl | 81 +++++++++++++++++++++++----------------------- test/mcmc/gibbs.jl | 3 +- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index fd32e9b541..4008a7b906 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -31,25 +31,29 @@ isgibbscomponent(::AdvancedMH.MetropolisHastings) = true # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. """ - GibbsContext(target_varnames, global_varinfo, context) + GibbsContext{VNs}(global_varinfo, context) A context used in the implementation of the Turing.jl Gibbs sampler. There will be one `GibbsContext` for each iteration of a component sampler. +`VNs` is a `Val` type of a tuple of symbols for `VarName`s that the current component +sampler is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume` +calls to its child context. For other variables, their values will be fixed to the values +they have in `global_varinfo`. + +The naive implementation of `GibbsContext` would simply have a field `target_varnames` that +would be a collection of `VarName`s that the current component sampler is sampling. The +reason we instead have a `Val` type parameter listing `Symbol`s is to allow +`is_target_varname` to benefit from compile time constant propagation. This is important +for type stability of `tilde_assume`. + # Fields $(FIELDS) """ -struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: +struct GibbsContext{VNs<:Val,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext """ - a collection of `VarName`s that the current component sampler is sampling. - For them, `GibbsContext` will just pass tilde_assume calls to its child context. - For other variables, their values will be fixed to the values they have in - `global_varinfo`. - """ - target_varnames::VNs - """ a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both those fixed and those being sampled. We use a `Ref` because this field may need to be updated if new variables are introduced. @@ -59,6 +63,19 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont the child context that tilde calls will eventually be passed onto. """ context::Ctx + + function GibbsContext{VNs}(global_varinfo, context) where {VNs} + return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context) + end + + # If the first argument is not already a Val, convert it to one. + function GibbsContext(target_varnames, global_varinfo, context) + # TODO(mhauru) Add a check that all target_varnames have identity lenses. + vn_sym = Val(tuple((DynamicPPL.getsym(vn) for vn in target_varnames)...)) + return new{typeof(vn_sym),typeof(global_varinfo),typeof(context)}( + global_varinfo, context + ) + end end function GibbsContext(target_varnames, global_varinfo) @@ -67,10 +84,8 @@ end DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context -function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) - return GibbsContext( - context.target_varnames, Ref(context.global_varinfo[]), childcontext - ) +function DynamicPPL.setchildcontext(context::GibbsContext{VNs}, childcontext) where {VNs} + return GibbsContext{VNs}(Ref(context.global_varinfo[]), childcontext) end get_global_varinfo(context::GibbsContext) = context.global_varinfo[] @@ -103,11 +118,11 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa end function is_target_varname(context::GibbsContext, vn::VarName) - return Iterators.any( - Iterators.map(target -> subsumes(target, vn), context.target_varnames) - ) + return is_target_varname(context, DynamicPPL.getsym(vn)) end +is_target_varname(::GibbsContext{Val{T}}, vn_symbol::Symbol) where {T} = vn_symbol in T + function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns)) if (num_target != 0) && (num_target != length(vns)) @@ -187,24 +202,6 @@ function DynamicPPL.tilde_assume( end end -# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. -make_broadcastable(x) = x -make_broadcastable(dist::Distribution) = tuple(dist) - -# Need the following two methods to properly support broadcasting over columns. -broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) -function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) - return loglikelihood(dist, x) -end - -# Needed to support broadcasting over columns for `MultivariateDistribution`s. -reconstruct_getvalue(dist, x) = x -function reconstruct_getvalue( - dist::MultivariateDistribution, x::AbstractVector{<:AbstractVector{<:Real}} -) - return reduce(hcat, x[2:end]; init=x[1]) -end - # Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf. # See comments there for more details. function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) @@ -221,10 +218,9 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi ) value, lp, vi else - prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( DynamicPPL.childcontext(context), - prior_sampler, + DynamicPPL.SampleFromPrior(), right, left, vns, @@ -254,11 +250,10 @@ function DynamicPPL.dot_tilde_assume( ) value, lp, vi else - prior_sampler = DynamicPPL.SampleFromPrior() value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( rng, DynamicPPL.childcontext(context), - prior_sampler, + DynamicPPL.SampleFromPrior(), right, left, vns, @@ -294,10 +289,14 @@ function make_conditional( return DynamicPPL.contextualize(model, gibbs_context), gibbs_context end -# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` -# or an `AbstractInferenceAlgorithm`. wrap_algorithm_maybe(x) = x -wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) +# All samplers are given the same Selector, so that they will also sample all variables +# given to them by the Gibbs sampler. This avoids conflicts between the new and the old way +# of choosing which sampler to use. +function wrap_algorithm_maybe(x::DynamicPPL.Sampler) + return DynamicPPL.Sampler(x.alg, DynamicPPL.Selector(0)) +end +wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) """ Gibbs diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index c8441c0517..75fefa6210 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -69,7 +69,8 @@ has_dot_assume(::DynamicPPL.Model) = true n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) ), ) - for target_vns in target_vn_combinations + + @testset "$(target_vns)" for target_vns in target_vn_combinations global_varinfo = DynamicPPL.VarInfo(model) target_vns = collect(target_vns) local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) From 5d401e1ec49b08ae76e620137904684c994df13f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 15 Nov 2024 18:57:25 +0000 Subject: [PATCH 44/70] Fix broken short circuit in MH --- src/mcmc/mh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 8a9c19a4e8..edd46a4572 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -252,7 +252,7 @@ function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVa end @generated function _val_tuple(vi::VarInfo, vns::NamedTuple{names}) where {names} - isempty(names) === 0 && return :(NamedTuple()) + isempty(names) && return :(NamedTuple()) expr = Expr(:tuple) expr.args = Any[ :( From d57c97c236729b06515be035bcaa8c4c376c1378 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 18 Nov 2024 15:17:02 +0000 Subject: [PATCH 45/70] Stop unnecessary use of Val in GibbsContext --- src/mcmc/gibbs.jl | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 4008a7b906..d16b92ed90 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -37,21 +37,21 @@ A context used in the implementation of the Turing.jl Gibbs sampler. There will be one `GibbsContext` for each iteration of a component sampler. -`VNs` is a `Val` type of a tuple of symbols for `VarName`s that the current component +`VNs` is a a tuple of symbols for `VarName`s that the current component sampler is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume` calls to its child context. For other variables, their values will be fixed to the values they have in `global_varinfo`. The naive implementation of `GibbsContext` would simply have a field `target_varnames` that would be a collection of `VarName`s that the current component sampler is sampling. The -reason we instead have a `Val` type parameter listing `Symbol`s is to allow +reason we instead have a `Tuple` type parameter listing `Symbol`s is to allow `is_target_varname` to benefit from compile time constant propagation. This is important for type stability of `tilde_assume`. # Fields $(FIELDS) """ -struct GibbsContext{VNs<:Val,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: +struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext """ a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both @@ -68,13 +68,16 @@ struct GibbsContext{VNs<:Val,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.Abstrac return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context) end - # If the first argument is not already a Val, convert it to one. function GibbsContext(target_varnames, global_varinfo, context) - # TODO(mhauru) Add a check that all target_varnames have identity lenses. - vn_sym = Val(tuple((DynamicPPL.getsym(vn) for vn in target_varnames)...)) - return new{typeof(vn_sym),typeof(global_varinfo),typeof(context)}( - global_varinfo, context - ) + if any(vn -> DynamicPPL.getoptic(vn) != identity, target_varnames) + msg = + "All Gibbs target variables must have identity lenses. " * + "For example, you can't have `@varname(x.a[1])` as a target variable, " * + "only `@varname(x)`." + error(msg) + end + vn_sym = tuple(unique((DynamicPPL.getsym(vn) for vn in target_varnames))...) + return new{vn_sym,typeof(global_varinfo),typeof(context)}(global_varinfo, context) end end @@ -121,7 +124,7 @@ function is_target_varname(context::GibbsContext, vn::VarName) return is_target_varname(context, DynamicPPL.getsym(vn)) end -is_target_varname(::GibbsContext{Val{T}}, vn_symbol::Symbol) where {T} = vn_symbol in T +is_target_varname(::GibbsContext{T}, vn_symbol::Symbol) where {T} = vn_symbol in T function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns)) From 5bebff2f6cb5a540d51c7db182a32dd7d836242f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 18 Nov 2024 15:21:04 +0000 Subject: [PATCH 46/70] Enforce GibbsContext being next to a leaf --- src/mcmc/gibbs.jl | 68 ++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index d16b92ed90..4fffeef1fc 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -65,10 +65,16 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont context::Ctx function GibbsContext{VNs}(global_varinfo, context) where {VNs} + if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf) + error("GibbsContext can only wrap a leaf context, not a $(context).") + end return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context) end function GibbsContext(target_varnames, global_varinfo, context) + if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf) + error("GibbsContext can only wrap a leaf context, not a $(context).") + end if any(vn -> DynamicPPL.getoptic(vn) != identity, target_varnames) msg = "All Gibbs target variables must have identity lenses. " * @@ -139,20 +145,14 @@ end # Tilde pipeline function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) + child_context = DynamicPPL.childcontext(context) return if is_target_varname(context, vn) # Fall back to the default behavior. - DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) + DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) # Short-circuit the tilde assume if `vn` is present in `context`. - child = DynamicPPL.childcontext(context) - if child isa SamplingContext - # TODO(mhauru) Would it ever be valid to have a SamplingContext as the child? - # We could just raise a warning, or optionally go down the stack of contexts - # skipping all SamplingContexts. The erroring is being conservative. - error("GibbsContext has a SamplingContext as its child.") - end value, lp, _ = DynamicPPL.tilde_assume( - child, right, vn, get_global_varinfo(context) + child_context, right, vn, get_global_varinfo(context) ) value, lp, vi else @@ -161,7 +161,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. value, lp, new_global_vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), + child_context, DynamicPPL.SampleFromPrior(), right, vn, @@ -177,24 +177,18 @@ function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi ) # See comment in the above, rng-less version of this method for an explanation. + child_context = DynamicPPL.childcontext(context) return if is_target_varname(context, vn) - DynamicPPL.tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, vn, vi - ) + DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - child = DynamicPPL.childcontext(context) - if child isa SamplingContext - # TODO(mhauru) See comment in the method above. - error("GibbsContext has a SamplingContext as its child.") - end value, lp, _ = DynamicPPL.tilde_assume( - child, right, vn, get_global_varinfo(context) + child_context, right, vn, get_global_varinfo(context) ) value, lp, vi else value, lp, new_global_vi = DynamicPPL.tilde_assume( rng, - DynamicPPL.childcontext(context), + child_context, DynamicPPL.SampleFromPrior(), right, vn, @@ -208,21 +202,17 @@ end # Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf. # See comments there for more details. function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) + child_context = DynamicPPL.childcontext(context) return if is_target_varname(context, vns) - DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) + DynamicPPL.dot_tilde_assume(child_context, right, left, vns, vi) elseif has_conditioned_gibbs(context, vns) - child = DynamicPPL.childcontext(context) - if child isa SamplingContext - # TODO(mhauru) See comment in the method above. - error("GibbsContext has a SamplingContext as its child.") - end value, lp, _ = DynamicPPL.dot_tilde_assume( - child, right, left, vns, get_global_varinfo(context) + child_context, right, left, vns, get_global_varinfo(context) ) value, lp, vi else value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( - DynamicPPL.childcontext(context), + child_context, DynamicPPL.SampleFromPrior(), right, left, @@ -238,24 +228,18 @@ end function DynamicPPL.dot_tilde_assume( rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi ) + child_context = DynamicPPL.childcontext(context) return if is_target_varname(context, vns) - DynamicPPL.dot_tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi - ) + DynamicPPL.dot_tilde_assume(rng, child_context, sampler, right, left, vns, vi) elseif has_conditioned_gibbs(context, vns) - child = DynamicPPL.childcontext(context) - if child isa SamplingContext - # TODO(mhauru) See comment in the method above. - error("GibbsContext has a SamplingContext as its child.") - end value, lp, _ = DynamicPPL.dot_tilde_assume( - child, right, left, vns, get_global_varinfo(context) + child_context, right, left, vns, get_global_varinfo(context) ) value, lp, vi else value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( rng, - DynamicPPL.childcontext(context), + child_context, DynamicPPL.SampleFromPrior(), right, left, @@ -288,7 +272,13 @@ because evaluation can mutate its `global_varinfo` field, which we need to acces function make_conditional( model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo ) - gibbs_context = GibbsContext(target_variables, Ref(varinfo), model.context) + # Insert the `GibbsContext` just before the leaf. + # 1. Extract the `leafcontext` from `model` and wrap in `GibbsContext`. + gibbs_context_inner = GibbsContext( + target_variables, Ref(varinfo), DynamicPPL.leafcontext(model.context) + ) + # 2. Set the leaf context to be the `GibbsContext` wrapping `leafcontext(model.context)`. + gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner) return DynamicPPL.contextualize(model, gibbs_context), gibbs_context end From fdf1347062fff3ac98e2af79caf6a165f78f470c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 19 Nov 2024 11:15:10 +0000 Subject: [PATCH 47/70] Fix setparams_varinfo!! for ESS --- src/mcmc/gibbs.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 4fffeef1fc..97afbd967b 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -481,7 +481,10 @@ function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) end function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Sampler{<:MH}, state::VarInfo, params::AbstractVarInfo + model::DynamicPPL.Model, + sampler::Sampler{<:MH}, + state::AbstractVarInfo, + params::AbstractVarInfo, ) # The state is already a VarInfo, so we can just return `params`, but first we need to # update its logprob. @@ -493,14 +496,14 @@ end function setparams_varinfo!!( model::DynamicPPL.Model, sampler::Sampler{<:ESS}, - state::VarInfo, + state::AbstractVarInfo, params::AbstractVarInfo, ) # The state is already a VarInfo, so we can just return `params`, but first we need to - # update its logprob. - # Note the use of LikelihoodContext, regardless of what context `model` has. This is - # specific to ESS as a sampler. - return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.LikelihoodContext())) + # update its logprob. To do this, we have to call evaluate!! with the sampler, rather + # than just a context, because ESS is peculiar in how it uses LikelihoodContext for + # some variables and DefaultContext for others. + return last(DynamicPPL.evaluate!!(model, params, sampler)) end function setparams_varinfo!!( From 8e8ed0dd8b33f7a7aeeada42eaf809ee0608382e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 19 Nov 2024 11:40:42 +0000 Subject: [PATCH 48/70] Fix a small Gibbs bug --- src/mcmc/gibbs.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 97afbd967b..a460791a6d 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -6,7 +6,6 @@ Return a boolean indicating whether `alg` is a valid component for a Gibbs sampl Defaults to `false` if no method has been defined for a particular algorithm type. """ isgibbscomponent(::InferenceAlgorithm) = false -isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler) isgibbscomponent(spl::Sampler) = isgibbscomponent(spl.alg) isgibbscomponent(::ESS) = true @@ -15,6 +14,8 @@ isgibbscomponent(::HMCDA) = true isgibbscomponent(::NUTS) = true isgibbscomponent(::MH) = true isgibbscomponent(::PG) = true + +isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler) isgibbscomponent(::AdvancedHMC.HMC) = true isgibbscomponent(::AdvancedMH.MetropolisHastings) = true @@ -279,7 +280,7 @@ function make_conditional( ) # 2. Set the leaf context to be the `GibbsContext` wrapping `leafcontext(model.context)`. gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner) - return DynamicPPL.contextualize(model, gibbs_context), gibbs_context + return DynamicPPL.contextualize(model, gibbs_context), gibbs_context_inner end wrap_algorithm_maybe(x) = x @@ -430,7 +431,7 @@ function DynamicPPL.initialstep( new_vi_local = varinfo(new_state_local) # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. - vi = merge(vi, context_local.global_varinfo[]) + vi = merge(vi, get_global_varinfo(context_local)) # Merge the new values for all the variables sampled by the current sampler. vi = merge(vi, new_vi_local) push!(states, new_state_local) From 2cced8be4fdaae580848a6990bb90e4ee70ac8df Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 19 Nov 2024 12:05:08 +0000 Subject: [PATCH 49/70] Fix Gibbs sampler test --- test/mcmc/gibbs.jl | 570 ++++++++++++++++++++++----------------------- 1 file changed, 282 insertions(+), 288 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 75fefa6210..c8eaefd70a 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -43,302 +43,294 @@ const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false has_dot_assume(::DynamicPPL.Model) = true -@testset "GibbsContext" begin - @testset "type stability" begin - # A test model that has multiple features in one package: - # Floats, Ints, arguments, observations, loops, dot_tildes. - @model function test_model(obs1, obs2, num_vars, mean) - variance ~ Exponential(2) - z = Vector{Float64}(undef, num_vars) - z .~ truncated(Normal(mean, variance); lower=1) - y = Vector{Int64}(undef, num_vars) - for i in 1:num_vars - y[i] ~ Poisson(Int(round(z[i]))) - end - s = sum(y) - sum(z) - obs1 ~ Normal(s, 1) - obs2 ~ Poisson(y[3]) - return obs1, obs2, variance, z, y, s - end - - model = test_model(1.2, 2, 10, 2.5) - all_varnames = DynamicPPL.VarName[@varname(variance), @varname(z), @varname(y)] - # All combinations of elements in all_varnames. - target_vn_combinations = Iterators.flatten( - Iterators.map( - n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) - ), - ) - - @testset "$(target_vns)" for target_vns in target_vn_combinations - global_varinfo = DynamicPPL.VarInfo(model) - target_vns = collect(target_vns) - local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) - ctx = Turing.Inference.GibbsContext( - target_vns, Ref(global_varinfo), Turing.DefaultContext() - ) - - # Check that the correct varnames are conditioned, and that getting their - # values is type stable when the varinfo is. - for k in keys(global_varinfo) - is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)) - @test Turing.Inference.is_target_varname(ctx, k) == is_target - if !is_target - @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) - end - end - - # Check the type stability also in the dot_tilde pipeline. - for k in all_varnames - # The map(identity, ...) part is there to concretise the eltype. - subkeys = map( - identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)) - ) - is_target = (k in target_vns) - @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target - if !is_target - @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) - end - end - - # Check that evaluate!! and the result it returns are type stable. - conditioned_model = DynamicPPL.contextualize(model, ctx) - _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( - conditioned_model, local_varinfo - ) - for k in keys(post_eval_varinfo) - @inferred post_eval_varinfo[k] - end - end - end -end - -@testset "Invalid Gibbs constructor" begin - # More samplers than varnames or vice versa - @test_throws ArgumentError Gibbs((@varname(s), @varname(m)), (NUTS(), NUTS(), NUTS())) - @test_throws ArgumentError Gibbs( - (@varname(s), @varname(m), @varname(x)), (NUTS(), NUTS()) - ) - # Invalid samplers - @test_throws ArgumentError Gibbs(@varname(s) => IS()) - @test_throws ArgumentError Gibbs(@varname(s) => Emcee(10, 2.0)) - @test_throws ArgumentError Gibbs( - @varname(s) => SGHMC(; learning_rate=0.01, momentum_decay=0.1) - ) - @test_throws ArgumentError Gibbs( - @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) - ) -end +# @testset "GibbsContext" begin +# @testset "type stability" begin +# # A test model that has multiple features in one package: +# # Floats, Ints, arguments, observations, loops, dot_tildes. +# @model function test_model(obs1, obs2, num_vars, mean) +# variance ~ Exponential(2) +# z = Vector{Float64}(undef, num_vars) +# z .~ truncated(Normal(mean, variance); lower=1) +# y = Vector{Int64}(undef, num_vars) +# for i in 1:num_vars +# y[i] ~ Poisson(Int(round(z[i]))) +# end +# s = sum(y) - sum(z) +# obs1 ~ Normal(s, 1) +# obs2 ~ Poisson(y[3]) +# return obs1, obs2, variance, z, y, s +# end + +# model = test_model(1.2, 2, 10, 2.5) +# all_varnames = DynamicPPL.VarName[@varname(variance), @varname(z), @varname(y)] +# # All combinations of elements in all_varnames. +# target_vn_combinations = Iterators.flatten( +# Iterators.map( +# n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) +# ), +# ) + +# @testset "$(target_vns)" for target_vns in target_vn_combinations +# global_varinfo = DynamicPPL.VarInfo(model) +# target_vns = collect(target_vns) +# local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) +# ctx = Turing.Inference.GibbsContext( +# target_vns, Ref(global_varinfo), Turing.DefaultContext() +# ) + +# # Check that the correct varnames are conditioned, and that getting their +# # values is type stable when the varinfo is. +# for k in keys(global_varinfo) +# is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)) +# @test Turing.Inference.is_target_varname(ctx, k) == is_target +# if !is_target +# @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) +# end +# end + +# # Check the type stability also in the dot_tilde pipeline. +# for k in all_varnames +# # The map(identity, ...) part is there to concretise the eltype. +# subkeys = map( +# identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)) +# ) +# is_target = (k in target_vns) +# @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target +# if !is_target +# @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) +# end +# end + +# # Check that evaluate!! and the result it returns are type stable. +# conditioned_model = DynamicPPL.contextualize(model, ctx) +# _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( +# conditioned_model, local_varinfo +# ) +# for k in keys(post_eval_varinfo) +# @inferred post_eval_varinfo[k] +# end +# end +# end +# end + +# @testset "Invalid Gibbs constructor" begin +# # More samplers than varnames or vice versa +# @test_throws ArgumentError Gibbs((@varname(s), @varname(m)), (NUTS(), NUTS(), NUTS())) +# @test_throws ArgumentError Gibbs( +# (@varname(s), @varname(m), @varname(x)), (NUTS(), NUTS()) +# ) +# # Invalid samplers +# @test_throws ArgumentError Gibbs(@varname(s) => IS()) +# @test_throws ArgumentError Gibbs(@varname(s) => Emcee(10, 2.0)) +# @test_throws ArgumentError Gibbs( +# @varname(s) => SGHMC(; learning_rate=0.01, momentum_decay=0.1) +# ) +# @test_throws ArgumentError Gibbs( +# @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) +# ) +# end @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends - @testset "Deprecated Gibbs constructors" begin - N = 10 - @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) - @test_deprecated s2 = Gibbs(PG(10, :s, :m)) - @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) - for s in (s1, s2, s3, s4, s5, s6, s7) - @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" - end - - # Check that the samplers work despite using the deprecated constructor. - sample(gdemo_default, s1, N) - sample(gdemo_default, s2, N) - sample(gdemo_default, s3, N) - sample(gdemo_default, s4, N) - sample(gdemo_default, s5, N) - sample(gdemo_default, s6, N) - sample(gdemo_default, s7, N) - - g = Turing.Sampler(s3, gdemo_default) - @test sample(gdemo_default, g, N) isa MCMCChains.Chains - end - - @testset "Gibbs constructors" begin - # Create Gibbs samplers with various configurations and ways of passing the - # arguments, and run them all on the `gdemo_default` model, see that nothing breaks. - N = 10 - # Two variables being sampled by one sampler. - s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) - s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) - # One variable per sampler, using the keyword arg interface. - s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) - # As above but using a Dict of VarNames. - s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) - # As above but different samplers. - s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) - s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) - s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) - # Multiple instnaces of the same sampler. This implements running, in this case, - # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. - s8 = begin - hmc = HMC(0.1, 5; adtype=adbackend) - pg = PG(10) - vns = @varname(s) - vnm = @varname(m) - Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) - end - for s in (s1, s2, s3, s4, s5, s6, s7, s8) - @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" - end - - sample(gdemo_default, s1, N) - sample(gdemo_default, s2, N) - sample(gdemo_default, s3, N) - sample(gdemo_default, s4, N) - sample(gdemo_default, s5, N) - sample(gdemo_default, s6, N) - sample(gdemo_default, s7, N) - sample(gdemo_default, s8, N) - - g = Turing.Sampler(s3, gdemo_default) - @test sample(gdemo_default, g, N) isa MCMCChains.Chains - end + # @testset "Deprecated Gibbs constructors" begin + # N = 10 + # @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) + # @test_deprecated s2 = Gibbs(PG(10, :s, :m)) + # @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + # @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + # @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + # @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) + # @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) + # for s in (s1, s2, s3, s4, s5, s6, s7) + # @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" + # end + + # # Check that the samplers work despite using the deprecated constructor. + # sample(gdemo_default, s1, N) + # sample(gdemo_default, s2, N) + # sample(gdemo_default, s3, N) + # sample(gdemo_default, s4, N) + # sample(gdemo_default, s5, N) + # sample(gdemo_default, s6, N) + # sample(gdemo_default, s7, N) + + # g = Turing.Sampler(s3, gdemo_default) + # @test sample(gdemo_default, g, N) isa MCMCChains.Chains + # end + + # @testset "Gibbs constructors" begin + # # Create Gibbs samplers with various configurations and ways of passing the + # # arguments, and run them all on the `gdemo_default` model, see that nothing breaks. + # N = 10 + # # Two variables being sampled by one sampler. + # s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) + # s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) + # # One variable per sampler, using the keyword arg interface. + # s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) + # # As above but using a Dict of VarNames. + # s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) + # # As above but different samplers. + # s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) + # s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) + # s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) + # # Multiple instnaces of the same sampler. This implements running, in this case, + # # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. + # s8 = begin + # hmc = HMC(0.1, 5; adtype=adbackend) + # pg = PG(10) + # vns = @varname(s) + # vnm = @varname(m) + # Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) + # end + # for s in (s1, s2, s3, s4, s5, s6, s7, s8) + # @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" + # end + + # sample(gdemo_default, s1, N) + # sample(gdemo_default, s2, N) + # sample(gdemo_default, s3, N) + # sample(gdemo_default, s4, N) + # sample(gdemo_default, s5, N) + # sample(gdemo_default, s6, N) + # sample(gdemo_default, s7, N) + # sample(gdemo_default, s8, N) + + # g = Turing.Sampler(s3, gdemo_default) + # @test sample(gdemo_default, g, N) isa MCMCChains.Chains + # end @testset "gibbs inference" begin - Random.seed!(100) - alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:m], [7 / 6]; atol=0.15) - # Be more relaxed with the tolerance of the variance. - check_numerical(chain, [:s], [49 / 24]; atol=0.35) - - Random.seed!(100) - - alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - alg = Gibbs(; s=CSMC(15), m=ESS()) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - alg = CSMC(15) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - Random.seed!(200) - gibbs = Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), - (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), - ) - chain = sample(MoGtest_default, gibbs, 10_000) - check_MoGtest_default(chain; atol=0.15) - - Random.seed!(200) - for alg in [ - # The new syntax for specifying a sampler to run twice for one variable. - Gibbs( - @varname(s) => MH(), - @varname(s) => MH(), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - ), - Gibbs( - @varname(s) => MH(), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - ), - ] - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_gdemo(chain; atol=0.15) - end + # Random.seed!(100) + # alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) + # chain = sample(gdemo(1.5, 2.0), alg, 10_000) + # check_numerical(chain, [:m], [7 / 6]; atol=0.15) + # # Be more relaxed with the tolerance of the variance. + # check_numerical(chain, [:s], [49 / 24]; atol=0.35) + + # Random.seed!(100) + + # alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) + # chain = sample(gdemo(1.5, 2.0), alg, 10_000) + # check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + + # alg = Gibbs(; s=CSMC(15), m=ESS()) + # chain = sample(gdemo(1.5, 2.0), alg, 10_000) + # check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + + # alg = CSMC(15) + # chain = sample(gdemo(1.5, 2.0), alg, 10_000) + # check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + + # Random.seed!(200) + # gibbs = Gibbs( + # (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + # (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + # ) + # chain = sample(MoGtest_default, gibbs, 10_000) + # check_MoGtest_default(chain; atol=0.15) + + # Random.seed!(200) + # for alg in [ + # # The new syntax for specifying a sampler to run twice for one variable. + # Gibbs( + # @varname(s) => MH(), + # @varname(s) => MH(), + # @varname(m) => HMC(0.2, 4; adtype=adbackend), + # ), + # Gibbs( + # @varname(s) => MH(), + # @varname(m) => HMC(0.2, 4; adtype=adbackend), + # @varname(m) => HMC(0.2, 4; adtype=adbackend), + # ), + # ] + # chain = sample(gdemo(1.5, 2.0), alg, 10_000) + # check_gdemo(chain; atol=0.15) + # end end - @testset "transitions" begin - @model function gdemo_copy() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - 1.5 ~ Normal(m, sqrt(s)) - 2.0 ~ Normal(m, sqrt(s)) - return s, m - end - model = gdemo_copy() - - @nospecialize function AbstractMCMC.bundle_samples( - samples::Vector, - ::typeof(model), - ::Turing.Sampler{<:Gibbs}, - state, - ::Type{MCMCChains.Chains}; - kwargs..., - ) - samples isa Vector{<:Inference.Transition} || error("incorrect transitions") - return nothing - end - - function callback(rng, model, sampler, sample, state, i; kwargs...) - sample isa Inference.Transition || error("incorrect sample") - return nothing - end - - alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) - sample(model, alg, 100; callback=callback) - end - - @testset "dynamic model" begin - # TODO(mhauru) We should check that the results of the sampling are correct. - # Currently we just check that this doesn't crash. - @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} - N = length(y) - rpm = DirichletProcess(alpha) - - z = zeros(Int, N) - cluster_counts = zeros(Int, N) - fill!(cluster_counts, 0) - - for i in 1:N - z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts) - cluster_counts[z[i]] += 1 - end - - Kmax = findlast(!iszero, cluster_counts) - m = M(undef, Kmax) - for k in 1:Kmax - m[k] ~ Normal(1.0, 1.0) - end - end - model = imm(Random.randn(100), 1.0) - # https://github.com/TuringLang/Turing.jl/issues/1725 - # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); - sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) - end - - @testset "dynamic model with dot tilde" begin - @model function dynamic_model_with_dot_tilde( - num_zs=10, ::Type{M}=Vector{Float64} - ) where {M} - z = M(undef, num_zs) - z .~ Poisson(1.0) - num_ms = sum(z) - m = M(undef, num_ms) - return m .~ Normal(1.0, 1.0) - end - model = dynamic_model_with_dot_tilde() - # TODO(mhauru) This is broken because of - # https://github.com/TuringLang/DynamicPPL.jl/issues/700. - @test_broken ( - sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); - true - ) - end + # @testset "transitions" begin + # @model function gdemo_copy() + # s ~ InverseGamma(2, 3) + # m ~ Normal(0, sqrt(s)) + # 1.5 ~ Normal(m, sqrt(s)) + # 2.0 ~ Normal(m, sqrt(s)) + # return s, m + # end + # model = gdemo_copy() + + # @nospecialize function AbstractMCMC.bundle_samples( + # samples::Vector, + # ::typeof(model), + # ::Turing.Sampler{<:Gibbs}, + # state, + # ::Type{MCMCChains.Chains}; + # kwargs..., + # ) + # samples isa Vector{<:Inference.Transition} || error("incorrect transitions") + # return nothing + # end + + # function callback(rng, model, sampler, sample, state, i; kwargs...) + # sample isa Inference.Transition || error("incorrect sample") + # return nothing + # end + + # alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) + # sample(model, alg, 100; callback=callback) + # end + + # @testset "dynamic model" begin + # # TODO(mhauru) We should check that the results of the sampling are correct. + # # Currently we just check that this doesn't crash. + # @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} + # N = length(y) + # rpm = DirichletProcess(alpha) + + # z = zeros(Int, N) + # cluster_counts = zeros(Int, N) + # fill!(cluster_counts, 0) + + # for i in 1:N + # z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts) + # cluster_counts[z[i]] += 1 + # end + + # Kmax = findlast(!iszero, cluster_counts) + # m = M(undef, Kmax) + # for k in 1:Kmax + # m[k] ~ Normal(1.0, 1.0) + # end + # end + # model = imm(Random.randn(100), 1.0) + # # https://github.com/TuringLang/Turing.jl/issues/1725 + # # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); + # sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) + # end + + # @testset "dynamic model with dot tilde" begin + # @model function dynamic_model_with_dot_tilde( + # num_zs=10, ::Type{M}=Vector{Float64} + # ) where {M} + # z = M(undef, num_zs) + # z .~ Poisson(1.0) + # num_ms = sum(z) + # m = M(undef, num_ms) + # return m .~ Normal(1.0, 1.0) + # end + # model = dynamic_model_with_dot_tilde() + # # TODO(mhauru) This is broken because of + # # https://github.com/TuringLang/DynamicPPL.jl/issues/700. + # @test_broken ( + # sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); + # true + # ) + # end @testset "Demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) - # Run one sampler on variables starting with `s` and another on variables starting with `m`. - vns_s = filter(vns) do vn - DynamicPPL.getsym(vn) == :s - end - vns_m = filter(vns) do vn - DynamicPPL.getsym(vn) == :m - end - samplers = [ - Turing.Gibbs(vns_s => NUTS(), vns_m => NUTS()), - Turing.Gibbs(vns_s => NUTS(), vns_m => HMC(0.01, 4)), + Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()), + Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => HMC(0.01, 4)), ] if !has_dot_assume(model) @@ -346,8 +338,8 @@ end append!( samplers, [ - Turing.Gibbs(vns_s => HMC(0.01, 4), vns_m => MH()), - Turing.Gibbs(vns_s => MH(), vns_m => HMC(0.01, 4)), + Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()), + Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)), ], ) end @@ -385,7 +377,9 @@ end # Sampler to use for Gibbs components. sampler_inner = HMC(0.1, 32) - sampler = Turing.Gibbs(vns_s => sampler_inner, vns_m => sampler_inner) + sampler = Turing.Gibbs( + @varname(s) => sampler_inner, @varname(m) => sampler_inner + ) Random.seed!(42) chain = sample( model, From e1dbf6c7b158737ca7912f8a395598d59a65f728 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 20 Nov 2024 11:37:29 +0000 Subject: [PATCH 50/70] Add back tests that were accidentally commented out --- test/mcmc/gibbs.jl | 564 +++++++++++++++++++++++---------------------- 1 file changed, 289 insertions(+), 275 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index c8eaefd70a..91baf27731 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -43,287 +43,301 @@ const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false has_dot_assume(::DynamicPPL.Model) = true -# @testset "GibbsContext" begin -# @testset "type stability" begin -# # A test model that has multiple features in one package: -# # Floats, Ints, arguments, observations, loops, dot_tildes. -# @model function test_model(obs1, obs2, num_vars, mean) -# variance ~ Exponential(2) -# z = Vector{Float64}(undef, num_vars) -# z .~ truncated(Normal(mean, variance); lower=1) -# y = Vector{Int64}(undef, num_vars) -# for i in 1:num_vars -# y[i] ~ Poisson(Int(round(z[i]))) -# end -# s = sum(y) - sum(z) -# obs1 ~ Normal(s, 1) -# obs2 ~ Poisson(y[3]) -# return obs1, obs2, variance, z, y, s -# end - -# model = test_model(1.2, 2, 10, 2.5) -# all_varnames = DynamicPPL.VarName[@varname(variance), @varname(z), @varname(y)] -# # All combinations of elements in all_varnames. -# target_vn_combinations = Iterators.flatten( -# Iterators.map( -# n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) -# ), -# ) - -# @testset "$(target_vns)" for target_vns in target_vn_combinations -# global_varinfo = DynamicPPL.VarInfo(model) -# target_vns = collect(target_vns) -# local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) -# ctx = Turing.Inference.GibbsContext( -# target_vns, Ref(global_varinfo), Turing.DefaultContext() -# ) - -# # Check that the correct varnames are conditioned, and that getting their -# # values is type stable when the varinfo is. -# for k in keys(global_varinfo) -# is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)) -# @test Turing.Inference.is_target_varname(ctx, k) == is_target -# if !is_target -# @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) -# end -# end - -# # Check the type stability also in the dot_tilde pipeline. -# for k in all_varnames -# # The map(identity, ...) part is there to concretise the eltype. -# subkeys = map( -# identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)) -# ) -# is_target = (k in target_vns) -# @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target -# if !is_target -# @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) -# end -# end - -# # Check that evaluate!! and the result it returns are type stable. -# conditioned_model = DynamicPPL.contextualize(model, ctx) -# _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( -# conditioned_model, local_varinfo -# ) -# for k in keys(post_eval_varinfo) -# @inferred post_eval_varinfo[k] -# end -# end -# end -# end - -# @testset "Invalid Gibbs constructor" begin -# # More samplers than varnames or vice versa -# @test_throws ArgumentError Gibbs((@varname(s), @varname(m)), (NUTS(), NUTS(), NUTS())) -# @test_throws ArgumentError Gibbs( -# (@varname(s), @varname(m), @varname(x)), (NUTS(), NUTS()) -# ) -# # Invalid samplers -# @test_throws ArgumentError Gibbs(@varname(s) => IS()) -# @test_throws ArgumentError Gibbs(@varname(s) => Emcee(10, 2.0)) -# @test_throws ArgumentError Gibbs( -# @varname(s) => SGHMC(; learning_rate=0.01, momentum_decay=0.1) -# ) -# @test_throws ArgumentError Gibbs( -# @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) -# ) -# end +@testset "GibbsContext" begin + @testset "type stability" begin + # A test model that has multiple features in one package: + # Floats, Ints, arguments, observations, loops, dot_tildes. + @model function test_model(obs1, obs2, num_vars, mean) + variance ~ Exponential(2) + z = Vector{Float64}(undef, num_vars) + z .~ truncated(Normal(mean, variance); lower=1) + y = Vector{Int64}(undef, num_vars) + for i in 1:num_vars + y[i] ~ Poisson(Int(round(z[i]))) + end + s = sum(y) - sum(z) + obs1 ~ Normal(s, 1) + obs2 ~ Poisson(y[3]) + return obs1, obs2, variance, z, y, s + end + + model = test_model(1.2, 2, 10, 2.5) + all_varnames = DynamicPPL.VarName[@varname(variance), @varname(z), @varname(y)] + # All combinations of elements in all_varnames. + target_vn_combinations = Iterators.flatten( + Iterators.map( + n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) + ), + ) + + @testset "$(target_vns)" for target_vns in target_vn_combinations + global_varinfo = DynamicPPL.VarInfo(model) + target_vns = collect(target_vns) + local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) + ctx = Turing.Inference.GibbsContext( + target_vns, Ref(global_varinfo), Turing.DefaultContext() + ) + + # Check that the correct varnames are conditioned, and that getting their + # values is type stable when the varinfo is. + for k in keys(global_varinfo) + is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)) + @test Turing.Inference.is_target_varname(ctx, k) == is_target + if !is_target + @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) + end + end + + # Check the type stability also in the dot_tilde pipeline. + for k in all_varnames + # The map(identity, ...) part is there to concretise the eltype. + subkeys = map( + identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)) + ) + is_target = (k in target_vns) + @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target + if !is_target + @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) + end + end + + # Check that evaluate!! and the result it returns are type stable. + conditioned_model = DynamicPPL.contextualize(model, ctx) + _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( + conditioned_model, local_varinfo + ) + for k in keys(post_eval_varinfo) + @inferred post_eval_varinfo[k] + end + end + end +end + +@testset "Invalid Gibbs constructor" begin + # More samplers than varnames or vice versa + @test_throws ArgumentError Gibbs((@varname(s), @varname(m)), (NUTS(), NUTS(), NUTS())) + @test_throws ArgumentError Gibbs( + (@varname(s), @varname(m), @varname(x)), (NUTS(), NUTS()) + ) + # Invalid samplers + @test_throws ArgumentError Gibbs(@varname(s) => IS()) + @test_throws ArgumentError Gibbs(@varname(s) => Emcee(10, 2.0)) + @test_throws ArgumentError Gibbs( + @varname(s) => SGHMC(; learning_rate=0.01, momentum_decay=0.1) + ) + @test_throws ArgumentError Gibbs( + @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) + ) +end @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends - # @testset "Deprecated Gibbs constructors" begin - # N = 10 - # @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) - # @test_deprecated s2 = Gibbs(PG(10, :s, :m)) - # @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - # @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - # @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - # @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - # @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) - # for s in (s1, s2, s3, s4, s5, s6, s7) - # @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" - # end - - # # Check that the samplers work despite using the deprecated constructor. - # sample(gdemo_default, s1, N) - # sample(gdemo_default, s2, N) - # sample(gdemo_default, s3, N) - # sample(gdemo_default, s4, N) - # sample(gdemo_default, s5, N) - # sample(gdemo_default, s6, N) - # sample(gdemo_default, s7, N) - - # g = Turing.Sampler(s3, gdemo_default) - # @test sample(gdemo_default, g, N) isa MCMCChains.Chains - # end - - # @testset "Gibbs constructors" begin - # # Create Gibbs samplers with various configurations and ways of passing the - # # arguments, and run them all on the `gdemo_default` model, see that nothing breaks. - # N = 10 - # # Two variables being sampled by one sampler. - # s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) - # s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) - # # One variable per sampler, using the keyword arg interface. - # s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) - # # As above but using a Dict of VarNames. - # s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) - # # As above but different samplers. - # s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) - # s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) - # s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) - # # Multiple instnaces of the same sampler. This implements running, in this case, - # # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. - # s8 = begin - # hmc = HMC(0.1, 5; adtype=adbackend) - # pg = PG(10) - # vns = @varname(s) - # vnm = @varname(m) - # Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) - # end - # for s in (s1, s2, s3, s4, s5, s6, s7, s8) - # @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" - # end - - # sample(gdemo_default, s1, N) - # sample(gdemo_default, s2, N) - # sample(gdemo_default, s3, N) - # sample(gdemo_default, s4, N) - # sample(gdemo_default, s5, N) - # sample(gdemo_default, s6, N) - # sample(gdemo_default, s7, N) - # sample(gdemo_default, s8, N) - - # g = Turing.Sampler(s3, gdemo_default) - # @test sample(gdemo_default, g, N) isa MCMCChains.Chains - # end + @testset "Deprecated Gibbs constructors" begin + N = 10 + @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) + @test_deprecated s2 = Gibbs(PG(10, :s, :m)) + @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) + @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) + for s in (s1, s2, s3, s4, s5, s6, s7) + @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" + end + + # Check that the samplers work despite using the deprecated constructor. + sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) + + g = Turing.Sampler(s3, gdemo_default) + @test sample(gdemo_default, g, N) isa MCMCChains.Chains + end + + @testset "Gibbs constructors" begin + # Create Gibbs samplers with various configurations and ways of passing the + # arguments, and run them all on the `gdemo_default` model, see that nothing breaks. + N = 10 + # Two variables being sampled by one sampler. + s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) + s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) + # One variable per sampler, using the keyword arg interface. + s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) + # As above but using a Dict of VarNames. + s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) + # As above but different samplers. + s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) + s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) + s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) + # Multiple instnaces of the same sampler. This implements running, in this case, + # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. + s8 = begin + hmc = HMC(0.1, 5; adtype=adbackend) + pg = PG(10) + vns = @varname(s) + vnm = @varname(m) + Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) + end + for s in (s1, s2, s3, s4, s5, s6, s7, s8) + @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" + end + + sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) + sample(gdemo_default, s8, N) + + g = Turing.Sampler(s3, gdemo_default) + @test sample(gdemo_default, g, N) isa MCMCChains.Chains + end @testset "gibbs inference" begin - # Random.seed!(100) - # alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) - # chain = sample(gdemo(1.5, 2.0), alg, 10_000) - # check_numerical(chain, [:m], [7 / 6]; atol=0.15) - # # Be more relaxed with the tolerance of the variance. - # check_numerical(chain, [:s], [49 / 24]; atol=0.35) - - # Random.seed!(100) - - # alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) - # chain = sample(gdemo(1.5, 2.0), alg, 10_000) - # check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - # alg = Gibbs(; s=CSMC(15), m=ESS()) - # chain = sample(gdemo(1.5, 2.0), alg, 10_000) - # check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - # alg = CSMC(15) - # chain = sample(gdemo(1.5, 2.0), alg, 10_000) - # check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - # Random.seed!(200) - # gibbs = Gibbs( - # (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), - # (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), - # ) - # chain = sample(MoGtest_default, gibbs, 10_000) - # check_MoGtest_default(chain; atol=0.15) - - # Random.seed!(200) - # for alg in [ - # # The new syntax for specifying a sampler to run twice for one variable. - # Gibbs( - # @varname(s) => MH(), - # @varname(s) => MH(), - # @varname(m) => HMC(0.2, 4; adtype=adbackend), - # ), - # Gibbs( - # @varname(s) => MH(), - # @varname(m) => HMC(0.2, 4; adtype=adbackend), - # @varname(m) => HMC(0.2, 4; adtype=adbackend), - # ), - # ] - # chain = sample(gdemo(1.5, 2.0), alg, 10_000) - # check_gdemo(chain; atol=0.15) - # end + Random.seed!(100) + alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_numerical(chain, [:m], [7 / 6]; atol=0.15) + # Be more relaxed with the tolerance of the variance. + check_numerical(chain, [:s], [49 / 24]; atol=0.35) + + Random.seed!(100) + + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + + alg = Gibbs(; s=CSMC(15), m=ESS()) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + + alg = CSMC(15) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + + Random.seed!(200) + gibbs = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + ) + chain = sample(MoGtest_default, gibbs, 10_000) + check_MoGtest_default(chain; atol=0.15) + + Random.seed!(200) + for alg in [ + # The new syntax for specifying a sampler to run twice for one variable. + Gibbs( + @varname(s) => MH(), + @varname(s) => MH(), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + ), + Gibbs( + @varname(s) => MH(), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + ), + ] + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_gdemo(chain; atol=0.15) + end end - # @testset "transitions" begin - # @model function gdemo_copy() - # s ~ InverseGamma(2, 3) - # m ~ Normal(0, sqrt(s)) - # 1.5 ~ Normal(m, sqrt(s)) - # 2.0 ~ Normal(m, sqrt(s)) - # return s, m - # end - # model = gdemo_copy() - - # @nospecialize function AbstractMCMC.bundle_samples( - # samples::Vector, - # ::typeof(model), - # ::Turing.Sampler{<:Gibbs}, - # state, - # ::Type{MCMCChains.Chains}; - # kwargs..., - # ) - # samples isa Vector{<:Inference.Transition} || error("incorrect transitions") - # return nothing - # end - - # function callback(rng, model, sampler, sample, state, i; kwargs...) - # sample isa Inference.Transition || error("incorrect sample") - # return nothing - # end - - # alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) - # sample(model, alg, 100; callback=callback) - # end - - # @testset "dynamic model" begin - # # TODO(mhauru) We should check that the results of the sampling are correct. - # # Currently we just check that this doesn't crash. - # @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} - # N = length(y) - # rpm = DirichletProcess(alpha) - - # z = zeros(Int, N) - # cluster_counts = zeros(Int, N) - # fill!(cluster_counts, 0) - - # for i in 1:N - # z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts) - # cluster_counts[z[i]] += 1 - # end - - # Kmax = findlast(!iszero, cluster_counts) - # m = M(undef, Kmax) - # for k in 1:Kmax - # m[k] ~ Normal(1.0, 1.0) - # end - # end - # model = imm(Random.randn(100), 1.0) - # # https://github.com/TuringLang/Turing.jl/issues/1725 - # # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); - # sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100) - # end - - # @testset "dynamic model with dot tilde" begin - # @model function dynamic_model_with_dot_tilde( - # num_zs=10, ::Type{M}=Vector{Float64} - # ) where {M} - # z = M(undef, num_zs) - # z .~ Poisson(1.0) - # num_ms = sum(z) - # m = M(undef, num_ms) - # return m .~ Normal(1.0, 1.0) - # end - # model = dynamic_model_with_dot_tilde() - # # TODO(mhauru) This is broken because of - # # https://github.com/TuringLang/DynamicPPL.jl/issues/700. - # @test_broken ( - # sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); - # true - # ) - # end + @testset "transitions" begin + @model function gdemo_copy() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + return s, m + end + model = gdemo_copy() + + @nospecialize function AbstractMCMC.bundle_samples( + samples::Vector, + ::typeof(model), + ::Turing.Sampler{<:Gibbs}, + state, + ::Type{MCMCChains.Chains}; + kwargs..., + ) + samples isa Vector{<:Inference.Transition} || error("incorrect transitions") + return nothing + end + + function callback(rng, model, sampler, sample, state, i; kwargs...) + sample isa Inference.Transition || error("incorrect sample") + return nothing + end + + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) + sample(model, alg, 100; callback=callback) + end + + @testset "dynamic model" begin + @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} + N = length(y) + rpm = DirichletProcess(alpha) + + z = zeros(Int, N) + cluster_counts = zeros(Int, N) + fill!(cluster_counts, 0) + + for i in 1:N + z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts) + cluster_counts[z[i]] += 1 + end + + Kmax = findlast(!iszero, cluster_counts) + m = M(undef, Kmax) + for k in 1:Kmax + m[k] ~ Normal(1.0, 1.0) + end + end + num_zs = 100 + num_samples = 10_000 + model = imm(Random.randn(num_zs), 1.0) + # https://github.com/TuringLang/Turing.jl/issues/1725 + # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); + chn = sample( + model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), num_samples + ) + # The number of m variables that have a non-zero value in a sample. + num_ms = count(ismissing.(Array(chn[:, (num_zs + 1):end, 1])); dims=2) + # The below are regression tests. The values we are comparing against are from + # running the above model on the "old" Gibbs sampler that was in place still on + # 2024-11-20. The model was run 5 times with 10_000 samples each time. The values + # to compare to are the mean of those 5 runs, atol is roughly estimated from the + # standard deviation of those 5 runs. + # TODO(mhauru) Could we do something smarter here? Maybe a dynamic model for which + # the posterior is analytically known? Doing 10_000 samples to run the test suite + # is not ideal + @test isapprox(mean(num_ms), 8.6087; atol=0.5) + @test isapprox(std(num_ms), 1.8865; atol=0.01) + end + + @testset "dynamic model with dot tilde" begin + @model function dynamic_model_with_dot_tilde( + num_zs=10, ::Type{M}=Vector{Float64} + ) where {M} + z = M(undef, num_zs) + z .~ Poisson(1.0) + num_ms = sum(z) + m = M(undef, num_ms) + return m .~ Normal(1.0, 1.0) + end + model = dynamic_model_with_dot_tilde() + # TODO(mhauru) This is broken because of + # https://github.com/TuringLang/DynamicPPL.jl/issues/700. + @test_broken ( + sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); + true + ) + end @testset "Demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS From 41b25f8884f9804e7f501c72ca0e34b1809de668 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 20 Nov 2024 14:05:08 +0000 Subject: [PATCH 51/70] Relax a test tolerance --- test/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 91baf27731..1d31464327 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -317,7 +317,7 @@ end # the posterior is analytically known? Doing 10_000 samples to run the test suite # is not ideal @test isapprox(mean(num_ms), 8.6087; atol=0.5) - @test isapprox(std(num_ms), 1.8865; atol=0.01) + @test isapprox(std(num_ms), 1.8865; atol=0.02) end @testset "dynamic model with dot tilde" begin From 583d297907f630872796fa1f4e641b109e02877a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 26 Nov 2024 15:52:50 +0000 Subject: [PATCH 52/70] Add a Gibbs test for dynamic model with ESS --- test/mcmc/gibbs.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 1d31464327..38b9341de2 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -320,6 +320,28 @@ end @test isapprox(std(num_ms), 1.8865; atol=0.02) end + # The below test used to sample incorrectly before + # https://github.com/TuringLang/Turing.jl/pull/2328 + @testset "dynamic model with ESS" begin + @model function dynamic_model_for_ess() + b ~ Bernoulli() + x_length = b ? 1 : 2 + x = Vector{Float64}(undef, x_length) + for i in 1:x_length + x[i] ~ Normal(i, 1.0) + end + end + + m = dynamic_model_for_ess() + chain = sample(m, Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100) + means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0) + stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0) + for vn in keys(means) + @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1) + @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1) + end + end + @testset "dynamic model with dot tilde" begin @model function dynamic_model_with_dot_tilde( num_zs=10, ::Type{M}=Vector{Float64} From cc9510c35bf7b6050fe9079e91dd0d079ae6fee7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 26 Nov 2024 16:16:10 +0000 Subject: [PATCH 53/70] Use ESS in Gibbs DEMO_MODELS tests --- test/mcmc/gibbs.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 38b9341de2..480b44f859 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -367,6 +367,7 @@ end samplers = [ Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()), Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => HMC(0.01, 4)), + Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => ESS()), ] if !has_dot_assume(model) @@ -412,10 +413,7 @@ end initial_params = fill(initial_params, num_chains) # Sampler to use for Gibbs components. - sampler_inner = HMC(0.1, 32) - sampler = Turing.Gibbs( - @varname(s) => sampler_inner, @varname(m) => sampler_inner - ) + sampler = Turing.Gibbs(@varname(s) => HMC(0.1, 32), @varname(m) => ESS()) Random.seed!(42) chain = sample( model, From 310bee9b4040a894a57d1afea98384b4f83abc32 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 27 Nov 2024 18:36:30 +0000 Subject: [PATCH 54/70] Add Gibbs component call order test --- test/mcmc/gibbs.jl | 121 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 480b44f859..45b986d6c7 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -130,6 +130,127 @@ end ) end +# Test that the samplers are being called in the correct order, on the correct target +# variables. +@testset "Sampler call order" begin + # A wrapper around inference algorithms to allow intercepting the dispatch cascade to + # collect testing information. + struct AlgWrapper{Alg<:Inference.InferenceAlgorithm} <: Inference.InferenceAlgorithm + inner::Alg + end + + unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) = + DynamicPPL.Sampler(sampler.alg.inner, sampler.selector) + + # Methods we need to define to be able to use AlgWrapper instead of an actual algorithm. + # They all just propagate the call to the inner algorithm. + Inference.isgibbscomponent(wrap::AlgWrapper) = Inference.isgibbscomponent(wrap.inner) + Inference.drop_space(wrap::AlgWrapper) = AlgWrapper(Inference.drop_space(wrap.inner)) + function Inference.setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgWrapper}, + state, + params::Turing.AbstractVarInfo, + ) + return Inference.setparams_varinfo!!(model, unwrap_sampler(sampler), state, params) + end + + function target_vns(::Inference.GibbsContext{VNs}) where {VNs} + return VNs + end + + # targets_and_algs will be a list of tuples, where the first element is the target_vns + # of a component sampler, and the second element is the component sampler itself. + # It is modified by the capture_targets_and_algs function. + targets_and_algs = Any[] + + function capture_targets_and_algs(sampler, context) + if DynamicPPL.NodeTrait(context) == DynamicPPL.IsLeaf() + return nothing + end + if context isa Inference.GibbsContext + push!(targets_and_algs, (target_vns(context), sampler)) + end + return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context)) + end + + # The methods that capture testing information for us. + function Turing.AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgWrapper}, + args...; + kwargs..., + ) + capture_targets_and_algs(sampler.alg.inner, model.context) + return Turing.AbstractMCMC.step( + rng, model, unwrap_sampler(sampler), args...; kwargs... + ) + end + + function Turing.DynamicPPL.initialstep( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgWrapper}, + args...; + kwargs..., + ) + capture_targets_and_algs(sampler.alg.inner, model.context) + return Turing.DynamicPPL.initialstep( + rng, model, unwrap_sampler(sampler), args...; kwargs... + ) + end + + # A test model that includes several different kinds of tilde syntax. + @model function test_model(val, ::Type{M}=Vector{Float64}) where {M} + s ~ Normal(0.1, 0.2) + m ~ Poisson() + val ~ Normal(s, 1) + 1.0 ~ Normal(s + m, 1) + + n := m + 1 + xs = M(undef, n) + for i in eachindex(xs) + xs[i] ~ Beta(0.5, 0.5) + end + + ys = M(undef, 2) + ys .~ Beta(1.0, 1.0) + return sum(xs), sum(ys), n + end + + mh = MH() + pg = PG(10) + hmc = HMC(0.01, 4) + nuts = NUTS() + # Sample with all sorts of combinations of samplers and targets. + sampler = Gibbs( + (@varname(s),) => AlgWrapper(mh), + (@varname(s), @varname(m)) => AlgWrapper(mh), + (@varname(m),) => AlgWrapper(pg), + (@varname(xs),) => AlgWrapper(hmc), + (@varname(ys),) => AlgWrapper(nuts), + (@varname(ys),) => AlgWrapper(nuts), + (@varname(xs), @varname(ys)) => AlgWrapper(hmc), + (@varname(s),) => AlgWrapper(mh), + ) + chain = sample(test_model(-1), sampler, 2) + + expected_targets_and_algs_per_iteration = [ + ((:s,), mh), + ((:s, :m), mh), + ((:m,), pg), + ((:xs,), hmc), + ((:ys,), nuts), + ((:ys,), nuts), + ((:xs, :ys), hmc), + ((:s,), mh), + ] + @test targets_and_algs == vcat( + expected_targets_and_algs_per_iteration, expected_targets_and_algs_per_iteration + ) +end + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "Deprecated Gibbs constructors" begin N = 10 From 519ff0251207ccef9d6d48be03f0f7e7ec67038d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 12:37:42 +0000 Subject: [PATCH 55/70] Fix Gibbs linking bug, add tests --- src/mcmc/gibbs.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ test/mcmc/gibbs.jl | 40 ++++++++++++++++++++++++---------------- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index a460791a6d..07c231e17c 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -543,6 +543,49 @@ function setparams_varinfo!!( return PGState(params, state.rng) end +""" + match_linking!!(varinfo_local, prev_state_local, model) + +Make sure the linked/invlinked status of varinfo_local matches that of the previous +state for this sampler. This is relevant when multilple samplers are sampling the same +variables, and one might need it to be linked while the other doesn't. +""" +function match_linking!!(varinfo_local, prev_state_local, model) + prev_varinfo_local = varinfo(prev_state_local) + was_linked = DynamicPPL.istrans(prev_varinfo_local) + is_linked = DynamicPPL.istrans(varinfo_local) + if was_linked && !is_linked + varinfo_local = DynamicPPL.link!!(varinfo_local, model) + elseif !was_linked && is_linked + varinfo_local = DynamicPPL.invlink!!(varinfo_local, model) + end + # TODO(mhauru) The above might run into trouble if some variables are linked and others + # are not. `istrans(varinfo)` returns an `all` over the individual variables. This could + # especially be a problem with dynamic models, where new variables may get introduced, + # but also in cases where component samplers have partial overlap in their target + # variables. The below is how I would like to implement this, but DynamicPPL at this + # time does not support linking individual variables selected by `VarName`. It soon + # should though, so come back to this. + # prev_links_dict = Dict(vn => DynamicPPL.istrans(prev_varinfo_local, vn) for vn in keys(prev_varinfo_local)) + # any_linked = any(values(prev_links_dict)) + # for vn in keys(varinfo_local) + # was_linked = if haskey(prev_varinfo_local, vn) + # prev_links_dict[vn] + # else + # # If the old state didn't have this variable, we assume it was linked if _any_ + # # of the variables of the old state were linked. + # any_linked + # end + # is_linked = DynamicPPL.istrans(varinfo_local, vn) + # if was_linked && !is_linked + # varinfo_local = DynamicPPL.invlink!!(varinfo_local, vn) + # elseif !was_linked && is_linked + # varinfo_local = DynamicPPL.link!!(varinfo_local, vn) + # end + # end + return varinfo_local +end + function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -555,6 +598,7 @@ function gibbs_step_inner( # Construct the conditional model and the varinfo that this sampler should use. model_local, context_local = make_conditional(model, varnames_local, global_vi) varinfo_local = subset(global_vi, varnames_local) + varinfo_local = match_linking!!(varinfo_local, state_local, model) # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not # sampled by other samplers, we don't need to `setparams`, but could rather simply diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 45b986d6c7..d1e9505e54 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -350,22 +350,30 @@ end check_MoGtest_default(chain; atol=0.15) Random.seed!(200) - for alg in [ - # The new syntax for specifying a sampler to run twice for one variable. - Gibbs( - @varname(s) => MH(), - @varname(s) => MH(), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - ), - Gibbs( - @varname(s) => MH(), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - ), - ] - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_gdemo(chain; atol=0.15) - end + # Test samplers that are run multiple times, or have overlapping targets. + alg = Gibbs( + @varname(s) => MH(), + (@varname(s), @varname(m)) => MH(), + @varname(m) => ESS(), + @varname(s) => MH(), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend), + ) + chain = sample(gdemo(1.5, 2.0), alg, 300) + check_gdemo(chain; atol=0.15) + + Random.seed!(200) + gibbs = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(z1), @varname(z2)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + (@varname(z3), @varname(z4)) => PG(15), + (@varname(mu1)) => ESS(), + (@varname(mu2)) => ESS(), + (@varname(z1), @varname(z2)) => PG(15), + ) + chain = sample(MoGtest_default, gibbs, 300) + check_MoGtest_default(chain; atol=0.15) end @testset "transitions" begin From f9ed562f392930b5acbf9ca85c1ebf438a40a91f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 13:52:02 +0000 Subject: [PATCH 56/70] Make Gibbs constructor more flexible --- src/mcmc/gibbs.jl | 18 ++++++++++++++---- test/mcmc/gibbs.jl | 6 +++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 07c231e17c..85e9dbb139 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -320,21 +320,30 @@ struct Gibbs{V,A} <: InferenceAlgorithm end end +to_varname(vn::VarName) = vn +to_varname(s::Symbol) = VarName{s}() +# Any other value is assumed to be an iterable. +to_varname(t) = map(to_varname, collect(t)) + # NamedTuple Gibbs(; algs...) = Gibbs(NamedTuple(algs)) function Gibbs(algs::NamedTuple) return Gibbs( - map(s -> VarName{s}(), keys(algs)), - map(wrap_algorithm_maybe ∘ drop_space, values(algs)), + map(to_varname, keys(algs)), map(wrap_algorithm_maybe ∘ drop_space, values(algs)) ) end # AbstractDict function Gibbs(algs::AbstractDict) - return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe ∘ drop_space, values(algs))) + return Gibbs( + map(to_varname, collect(keys(algs))), + map(wrap_algorithm_maybe ∘ drop_space, values(algs)), + ) end function Gibbs(algs::Pair...) - return Gibbs(map(first, algs), map(wrap_algorithm_maybe ∘ drop_space, map(last, algs))) + return Gibbs( + map(to_varname ∘ first, algs), map(wrap_algorithm_maybe ∘ drop_space ∘ last, algs) + ) end # The below two constructors only provide backwards compatibility with the constructor of @@ -383,6 +392,7 @@ end _maybevec(x) = vec(x) # assume it's iterable _maybevec(x::Tuple) = [x...] _maybevec(x::VarName) = [x] +_maybevec(x::Symbol) = [x] varinfo(state::GibbsState) = state.vi diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index d1e9505e54..6e03c48f40 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -284,15 +284,15 @@ end N = 10 # Two variables being sampled by one sampler. s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) - s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) + s2 = Gibbs((@varname(s), :m) => PG(10)) # One variable per sampler, using the keyword arg interface. s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) # As above but using a Dict of VarNames. s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) - # As above but different samplers. + # As above but different samplers and using kwargs. s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) - s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) + s7 = Gibbs(Dict((:s, @varname(m)) => PG(10))) # Multiple instnaces of the same sampler. This implements running, in this case, # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. s8 = begin From 7b0d12be75c7cc164258f4e3cd1f7c12b4707603 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 15:01:23 +0000 Subject: [PATCH 57/70] Introduce RepeatSampler --- HISTORY.md | 2 +- src/Turing.jl | 1 + src/mcmc/Inference.jl | 13 ++++++-- src/mcmc/gibbs.jl | 4 ++- src/mcmc/repeat_sampler.jl | 62 +++++++++++++++++++++++++++++++++++++ test/mcmc/gibbs.jl | 24 ++++++++++++-- test/mcmc/repeat_sampler.jl | 47 ++++++++++++++++++++++++++++ test/runtests.jl | 1 + 8 files changed, 146 insertions(+), 8 deletions(-) create mode 100644 src/mcmc/repeat_sampler.jl create mode 100644 test/mcmc/repeat_sampler.jl diff --git a/HISTORY.md b/HISTORY.md index 7f9b307c6d..ff50fb7795 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -12,7 +12,7 @@ may be accidental breakage that we haven't anticipated. Please report any you fi The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable. -Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(:x), 2), (MH(:y), 1))` has been deprecated. The new way to achieve this effect is to list the same sampler multiple times, e.g. as `hmc = HMC(); mh = MH(); Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)`. +Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`. # Release 0.35.0 diff --git a/src/Turing.jl b/src/Turing.jl index 513d0d239c..f1c5e407ab 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -95,6 +95,7 @@ export @model, # modelling SMC, CSMC, PG, + RepeatSampler, vi, # variational inference ADVI, sample, # inference diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 6765e7b72c..e984d0a2d0 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -74,6 +74,7 @@ export InferenceAlgorithm, SMC, CSMC, PG, + RepeatSampler, Prior, assume, dot_assume, @@ -100,6 +101,12 @@ Return an `InferenceAlgorithm` like `alg`, but with all space information remove """ function drop_space end +function drop_space(sampler::Sampler) + return Sampler(drop_space(sampler.alg), sampler.selector) +end + +include("repeat_sampler.jl") + """ ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} @@ -348,7 +355,7 @@ end function AbstractMCMC.sample( rng::AbstractRNG, model::AbstractModel, - sampler::Sampler{<:InferenceAlgorithm}, + sampler::Union{Sampler{<:InferenceAlgorithm},RepeatSampler}, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; @@ -460,7 +467,7 @@ getlogevidence(transitions, sampler, state) = missing function AbstractMCMC.bundle_samples( ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, + spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, chain_type::Type{MCMCChains.Chains}; save_state=false, @@ -523,7 +530,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, + spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, chain_type::Type{Vector{NamedTuple}}; kwargs..., diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 85e9dbb139..8bed40949d 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -15,6 +15,8 @@ isgibbscomponent(::NUTS) = true isgibbscomponent(::MH) = true isgibbscomponent(::PG) = true +isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler) + isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler) isgibbscomponent(::AdvancedHMC.HMC) = true isgibbscomponent(::AdvancedMH.MetropolisHastings) = true @@ -364,7 +366,7 @@ function Gibbs(algs::InferenceAlgorithm...) "`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " * "Please use `Gibbs(; x=NUTS(), y=MH())` instead. If you want different iteration " * "counts for different subsamplers, use e.g. " * - "`Gibbs(@varname(x) => NUTS(), @varname(x) => NUTS(), @varname(y) => MH())`" + "`Gibbs(@varname(x) => RepeatSampler(NUTS(), 2), @varname(y) => MH())`" ) Base.depwarn(msg, :Gibbs) return Gibbs(varnames, map(wrap_algorithm_maybe ∘ drop_space, algs)) diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl new file mode 100644 index 0000000000..a3e38f46a9 --- /dev/null +++ b/src/mcmc/repeat_sampler.jl @@ -0,0 +1,62 @@ +""" + RepeatSampler <: AbstractMCMC.AbstractSampler + +A `RepeatSampler` is a container for a sampler and a number of times to repeat it. + +# Fields +$(FIELDS) + +# Examples +```julia +repeated_sampler = RepeatSampler(sampler, 10) +AbstractMCMC.step(rng, model, repeated_sampler) # take 10 steps of `sampler` +``` +""" +struct RepeatSampler{S<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSampler + "The sampler to repeat" + sampler::S + "The number of times to repeat the sampler" + num_repeat::Int + + function RepeatSampler(sampler::S, num_repeat::Int) where {S} + @assert num_repeat > 0 + return new{S}(sampler, num_repeat) + end +end + +function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int) + return RepeatSampler(Sampler(alg), num_repeat) +end + +drop_space(rs::RepeatSampler) = RepeatSampler(drop_space(rs.sampler), rs.num_repeat) +getADType(spl::RepeatSampler) = getADType(spl.sampler) +DynamicPPL.default_chain_type(sampler::RepeatSampler) = default_chain_type(sampler.sampler) +DynamicPPL.getspace(spl::RepeatSampler) = getspace(spl.sampler) +DynamicPPL.inspace(vn::VarName, spl::RepeatSampler) = inspace(vn, spl.sampler) + +function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params) + return setparams_varinfo!!(model, sampler.sampler, state, params) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatSampler; + kwargs..., +) + return AbstractMCMC.step(rng, model, sampler.sampler; kwargs...) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatSampler, + state; + kwargs..., +) + transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...) + for _ in 2:(sampler.num_repeat) + transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...) + end + return transition, state +end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 6e03c48f40..59d694ac82 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -251,6 +251,18 @@ end ) end +@testset "Equivalence of RepeatSampler and repeating Sampler" begin + sampler1 = Gibbs(@varname(s) => RepeatSampler(MH(), 3), @varname(m) => ESS()) + sampler2 = Gibbs( + @varname(s) => MH(), @varname(s) => MH(), @varname(s) => MH(), @varname(m) => ESS() + ) + Random.seed!(23) + chain1 = sample(gdemo_default, sampler1, 10) + Random.seed!(23) + chain2 = sample(gdemo_default, sampler1, 10) + @test chain1.value == chain2.value +end + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @testset "Deprecated Gibbs constructors" begin N = 10 @@ -302,7 +314,12 @@ end vnm = @varname(m) Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) end - for s in (s1, s2, s3, s4, s5, s6, s7, s8) + # Same thing but using RepeatSampler. + s9 = Gibbs( + @varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3), + @varname(m) => RepeatSampler(PG(10), 2), + ) + for s in (s1, s2, s3, s4, s5, s6, s7, s8, s9) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end @@ -314,6 +331,7 @@ end sample(gdemo_default, s6, N) sample(gdemo_default, s7, N) sample(gdemo_default, s8, N) + sample(gdemo_default, s9, N) g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains @@ -355,7 +373,7 @@ end @varname(s) => MH(), (@varname(s), @varname(m)) => MH(), @varname(m) => ESS(), - @varname(s) => MH(), + @varname(s) => RepeatSampler(MH(), 3), @varname(m) => HMC(0.2, 4; adtype=adbackend), (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend), ) @@ -367,7 +385,7 @@ end (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), (@varname(z1), @varname(z2)) => PG(15), (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), - (@varname(z3), @varname(z4)) => PG(15), + (@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2), (@varname(mu1)) => ESS(), (@varname(mu2)) => ESS(), (@varname(z1), @varname(z2)) => PG(15), diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl new file mode 100644 index 0000000000..fa2f69bfef --- /dev/null +++ b/test/mcmc/repeat_sampler.jl @@ -0,0 +1,47 @@ +module HMCTests + +using ..Models: gdemo_default +using ..ADUtils: ADTypeCheckContext +using ..NumericalTests: check_gdemo, check_numerical +import ..ADUtils +using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample +import DynamicPPL +using DynamicPPL: Sampler +import ForwardDiff +using HypothesisTests: ApproximateTwoSampleKSTest, pvalue +import ReverseDiff +using LinearAlgebra: I, dot, vec +import Random +using StableRNGs: StableRNG +using StatsFuns: logistic +import Mooncake +using Test: @test, @test_logs, @testset, @test_throws +using Turing + +# RepeatedSampler only really makes sense as a component sampler of Gibbs. +# Here we just check that running it by itself is equivalent to thinning. +@testset "RepeatedSampler" begin + num_repeats = 17 + num_samples = 10 + num_chains = 2 + + rng = StableRNG(0) + for sampler in [MH(), Sampler(HMC(0.01, 4))] + chn1 = sample( + copy(rng), + gdemo_default, + sampler, + MCMCThreads(), + num_samples, + num_chains; + thinning=num_repeats, + ) + repeat_sampler = RepeatSampler(sampler, num_repeats) + chn2 = sample( + copy(rng), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, num_chains + ) + @test chn1.value == chn2.value + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index f827f6fc54..c63660d164 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,6 +60,7 @@ end @timeit_include("mcmc/abstractmcmc.jl") @timeit_include("mcmc/mh.jl") @timeit_include("ext/dynamichmc.jl") + @timeit_include("mcmc/repeat_sampler.jl") end @testset "variational algorithms" begin From 4beb463fae302607ee635ff9188569d6dd736284 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 15:58:48 +0000 Subject: [PATCH 58/70] Switch gold standard sample Gibbs test back to HMC I tried using ESS instead, because I thought it would test behavior a bit more broadly, given similarities between HMC and NUTS. It worked locally, but the KS test fails in one or two cases on CI. --- test/mcmc/gibbs.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 59d694ac82..65ca6dc82a 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -560,7 +560,8 @@ end initial_params = fill(initial_params, num_chains) # Sampler to use for Gibbs components. - sampler = Turing.Gibbs(@varname(s) => HMC(0.1, 32), @varname(m) => ESS()) + hmc = HMC(0.1, 32) + sampler = Turing.Gibbs(@varname(s) => hmc, @varname(m) => hmc) Random.seed!(42) chain = sample( model, From 104144980ea438bac55188fb7a72d58412fad501 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 16:09:46 +0000 Subject: [PATCH 59/70] Clean up RepeatSamplerTests preamble --- test/mcmc/repeat_sampler.jl | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl index fa2f69bfef..3519f75005 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -1,21 +1,9 @@ -module HMCTests +module RepeatSamplerTests using ..Models: gdemo_default -using ..ADUtils: ADTypeCheckContext -using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils -using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample -import DynamicPPL using DynamicPPL: Sampler -import ForwardDiff -using HypothesisTests: ApproximateTwoSampleKSTest, pvalue -import ReverseDiff -using LinearAlgebra: I, dot, vec -import Random using StableRNGs: StableRNG -using StatsFuns: logistic -import Mooncake -using Test: @test, @test_logs, @testset, @test_throws +using Test: @test, @testset using Turing # RepeatedSampler only really makes sense as a component sampler of Gibbs. From d17e6ed1f0b9be25b925a3f4c6405e8862759a6a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 16:21:38 +0000 Subject: [PATCH 60/70] Fix RepeatSampler in Gibbs bug --- src/mcmc/gibbs.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 8bed40949d..08601f94f4 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -292,6 +292,9 @@ wrap_algorithm_maybe(x) = x function wrap_algorithm_maybe(x::DynamicPPL.Sampler) return DynamicPPL.Sampler(x.alg, DynamicPPL.Selector(0)) end +function wrap_algorithm_maybe(x::RepeatSampler) + return RepeatSampler(wrap_algorithm_maybe(x.sampler), x.num_repeat) +end wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) """ From d22df392f259c070f31e8c6bd62ba7c253636462 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 16:23:54 +0000 Subject: [PATCH 61/70] Rename a function in Gibbs --- src/mcmc/gibbs.jl | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 08601f94f4..d1ce5451ee 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -285,17 +285,16 @@ function make_conditional( return DynamicPPL.contextualize(model, gibbs_context), gibbs_context_inner end -wrap_algorithm_maybe(x) = x -# All samplers are given the same Selector, so that they will also sample all variables +# All samplers are given the same Selector, so that they will sample all variables # given to them by the Gibbs sampler. This avoids conflicts between the new and the old way # of choosing which sampler to use. -function wrap_algorithm_maybe(x::DynamicPPL.Sampler) +function set_selector(x::DynamicPPL.Sampler) return DynamicPPL.Sampler(x.alg, DynamicPPL.Selector(0)) end -function wrap_algorithm_maybe(x::RepeatSampler) - return RepeatSampler(wrap_algorithm_maybe(x.sampler), x.num_repeat) +function set_selector(x::RepeatSampler) + return RepeatSampler(set_selector(x.sampler), x.num_repeat) end -wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) +set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) """ Gibbs @@ -333,22 +332,17 @@ to_varname(t) = map(to_varname, collect(t)) # NamedTuple Gibbs(; algs...) = Gibbs(NamedTuple(algs)) function Gibbs(algs::NamedTuple) - return Gibbs( - map(to_varname, keys(algs)), map(wrap_algorithm_maybe ∘ drop_space, values(algs)) - ) + return Gibbs(map(to_varname, keys(algs)), map(set_selector ∘ drop_space, values(algs))) end # AbstractDict function Gibbs(algs::AbstractDict) return Gibbs( - map(to_varname, collect(keys(algs))), - map(wrap_algorithm_maybe ∘ drop_space, values(algs)), + map(to_varname, collect(keys(algs))), map(set_selector ∘ drop_space, values(algs)) ) end function Gibbs(algs::Pair...) - return Gibbs( - map(to_varname ∘ first, algs), map(wrap_algorithm_maybe ∘ drop_space ∘ last, algs) - ) + return Gibbs(map(to_varname ∘ first, algs), map(set_selector ∘ drop_space ∘ last, algs)) end # The below two constructors only provide backwards compatibility with the constructor of @@ -372,7 +366,7 @@ function Gibbs(algs::InferenceAlgorithm...) "`Gibbs(@varname(x) => RepeatSampler(NUTS(), 2), @varname(y) => MH())`" ) Base.depwarn(msg, :Gibbs) - return Gibbs(varnames, map(wrap_algorithm_maybe ∘ drop_space, algs)) + return Gibbs(varnames, map(set_selector ∘ drop_space, algs)) end function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) From 2025ef94ce2f6d720497c7443f300dcd02878825 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 28 Nov 2024 17:01:11 +0000 Subject: [PATCH 62/70] Test HMCDA in Gibbs tests --- test/mcmc/gibbs.jl | 4 ++-- test/mcmc/repeat_sampler.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 65ca6dc82a..7aa8e9a263 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -302,7 +302,7 @@ end # As above but using a Dict of VarNames. s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) # As above but different samplers and using kwargs. - s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) + s5 = Gibbs(; s=CSMC(3), m=HMCDA(200, 0.65, 0.15; adtype=adbackend)) s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) s7 = Gibbs(Dict((:s, @varname(m)) => PG(10))) # Multiple instnaces of the same sampler. This implements running, in this case, @@ -347,7 +347,7 @@ end Random.seed!(100) - alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl index 3519f75005..7328d1168c 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -6,9 +6,9 @@ using StableRNGs: StableRNG using Test: @test, @testset using Turing -# RepeatedSampler only really makes sense as a component sampler of Gibbs. +# RepeatSampler only really makes sense as a component sampler of Gibbs. # Here we just check that running it by itself is equivalent to thinning. -@testset "RepeatedSampler" begin +@testset "RepeatSampler" begin num_repeats = 17 num_samples = 10 num_chains = 2 From 38ac128afdab19a0f76ac58a7bd8953e3489459a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 29 Nov 2024 09:57:43 +0000 Subject: [PATCH 63/70] Simplify is_target_varname --- src/mcmc/gibbs.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index d1ce5451ee..3993552ce5 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -129,11 +129,7 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa return map(Base.Fix1(get_conditioned_gibbs, context), vns) end -function is_target_varname(context::GibbsContext, vn::VarName) - return is_target_varname(context, DynamicPPL.getsym(vn)) -end - -is_target_varname(::GibbsContext{T}, vn_symbol::Symbol) where {T} = vn_symbol in T +is_target_varname(::GibbsContext{VNs}, ::VarName{sym}) where {VNs,sym} = sym in VNs function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns)) From 7d6a98333acfbdf0c62811f4ed4ab6fe1cec35e7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 29 Nov 2024 10:00:23 +0000 Subject: [PATCH 64/70] Add suggestions from code review Co-authored-by: Tor Erlend Fjelde --- src/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 3993552ce5..085cb82d85 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -509,7 +509,7 @@ function setparams_varinfo!!( # update its logprob. To do this, we have to call evaluate!! with the sampler, rather # than just a context, because ESS is peculiar in how it uses LikelihoodContext for # some variables and DefaultContext for others. - return last(DynamicPPL.evaluate!!(model, params, sampler)) + return last(DynamicPPL.evaluate!!(model, params, SamplingContext(sampler))) end function setparams_varinfo!!( From 39c2f2137f69c5592e3cbce1920a5e2b117c56bf Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 29 Nov 2024 10:18:38 +0000 Subject: [PATCH 65/70] Add a couple of issue references --- src/mcmc/gibbs.jl | 1 + test/mcmc/gibbs.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 085cb82d85..0f2c78ebe8 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -571,6 +571,7 @@ function match_linking!!(varinfo_local, prev_state_local, model) # variables. The below is how I would like to implement this, but DynamicPPL at this # time does not support linking individual variables selected by `VarName`. It soon # should though, so come back to this. + # Issue ref: https://github.com/TuringLang/Turing.jl/issues/2401 # prev_links_dict = Dict(vn => DynamicPPL.istrans(prev_varinfo_local, vn) for vn in keys(prev_varinfo_local)) # any_linked = any(values(prev_links_dict)) # for vn in keys(varinfo_local) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 7aa8e9a263..d0a3c67a42 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -463,6 +463,7 @@ end # TODO(mhauru) Could we do something smarter here? Maybe a dynamic model for which # the posterior is analytically known? Doing 10_000 samples to run the test suite # is not ideal + # Issue ref: https://github.com/TuringLang/Turing.jl/issues/2402 @test isapprox(mean(num_ms), 8.6087; atol=0.5) @test isapprox(std(num_ms), 1.8865; atol=0.02) end From ff0959107c8d129fc2d49448d36f0368c70c53a8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 29 Nov 2024 11:46:29 +0000 Subject: [PATCH 66/70] Restructure Gibbs inference tests and reduce iteration counts --- test/mcmc/gibbs.jl | 121 ++++++++++++++++++++++++--------------------- 1 file changed, 66 insertions(+), 55 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index d0a3c67a42..4da9fa6312 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -337,61 +337,72 @@ end @test sample(gdemo_default, g, N) isa MCMCChains.Chains end - @testset "gibbs inference" begin - Random.seed!(100) - alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:m], [7 / 6]; atol=0.15) - # Be more relaxed with the tolerance of the variance. - check_numerical(chain, [:s], [49 / 24]; atol=0.35) - - Random.seed!(100) - - alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend)) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - alg = Gibbs(; s=CSMC(15), m=ESS()) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - alg = CSMC(15) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - - Random.seed!(200) - gibbs = Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), - (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), - ) - chain = sample(MoGtest_default, gibbs, 10_000) - check_MoGtest_default(chain; atol=0.15) - - Random.seed!(200) - # Test samplers that are run multiple times, or have overlapping targets. - alg = Gibbs( - @varname(s) => MH(), - (@varname(s), @varname(m)) => MH(), - @varname(m) => ESS(), - @varname(s) => RepeatSampler(MH(), 3), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend), - ) - chain = sample(gdemo(1.5, 2.0), alg, 300) - check_gdemo(chain; atol=0.15) - - Random.seed!(200) - gibbs = Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), - (@varname(z1), @varname(z2)) => PG(15), - (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), - (@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2), - (@varname(mu1)) => ESS(), - (@varname(mu2)) => ESS(), - (@varname(z1), @varname(z2)) => PG(15), - ) - chain = sample(MoGtest_default, gibbs, 300) - check_MoGtest_default(chain; atol=0.15) + # Test various combinations of samplers against models for which we know the analytical + # posterior mean. + @testset "Gibbs inference" begin + @testset "CSMC and HMC on gdemo" begin + alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) + chain = sample(gdemo(1.5, 2.0), alg, 3_000) + check_numerical(chain, [:m], [7 / 6]; atol=0.15) + # Be more relaxed with the tolerance of the variance. + check_numerical(chain, [:s], [49 / 24]; atol=0.35) + end + + @testset "MH and HMCDA on gdemo" begin + alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend)) + chain = sample(gdemo(1.5, 2.0), alg, 3_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + end + + @testset "CSMC and ESS on gdemo" begin + alg = Gibbs(; s=CSMC(15), m=ESS()) + chain = sample(gdemo(1.5, 2.0), alg, 3_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + end + + # TODO(mhauru) Why is this in the Gibbs test suite? + @testset "CSMC on gdemo" begin + alg = CSMC(15) + chain = sample(gdemo(1.5, 2.0), alg, 4_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + end + + @testset "PG and HMC on MoGtest_default" begin + gibbs = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + ) + chain = sample(MoGtest_default, gibbs, 2_000) + check_MoGtest_default(chain; atol=0.15) + end + + @testset "Multiple overlapping samplers on gdemo" begin + # Test samplers that are run multiple times, or have overlapping targets. + alg = Gibbs( + @varname(s) => MH(), + (@varname(s), @varname(m)) => MH(), + @varname(m) => ESS(), + @varname(s) => RepeatSampler(MH(), 3), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend), + ) + chain = sample(gdemo(1.5, 2.0), alg, 500) + check_gdemo(chain; atol=0.15) + end + + @testset "Multiple overlapping samplers on MoGtest_default" begin + gibbs = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(z1), @varname(z2)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + (@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2), + (@varname(mu1)) => ESS(), + (@varname(mu2)) => ESS(), + (@varname(z1), @varname(z2)) => PG(15), + ) + chain = sample(MoGtest_default, gibbs, 500) + check_MoGtest_default(chain; atol=0.15) + end end @testset "transitions" begin From 1f5432f7681975b0ddf269a1c799e99546b48aa2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 29 Nov 2024 16:55:28 +0000 Subject: [PATCH 67/70] Reduce another iter count in Gibbs tests --- test/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 4da9fa6312..cdbaf6735f 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -638,7 +638,7 @@ end # `sample` Random.seed!(42) - chain = sample(model, alg, 10_000; progress=false) + chain = sample(model, alg, 1_000; progress=false) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) end From 5f21c8480f0710736d0933d92c3c0c5d25aa728c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 2 Dec 2024 15:11:20 +0000 Subject: [PATCH 68/70] Add an info print to Gibbs tests --- test/mcmc/gibbs.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index cdbaf6735f..f67e09d094 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -264,6 +264,7 @@ end end @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends + @info "Starting Gibbs tests with $adbackend" @testset "Deprecated Gibbs constructors" begin N = 10 @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) From a15ce2f49552e61be09634f6459eb2bb18f014f4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 3 Dec 2024 17:37:59 +0000 Subject: [PATCH 69/70] Use StableRNG, relax test tolerance --- test/mcmc/gibbs.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f67e09d094..731fab18d4 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -16,6 +16,7 @@ using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff import Mooncake +using StableRNGs: StableRNG using Test: @inferred, @test, @test_broken, @test_deprecated, @test_throws, @testset using Turing using Turing: Inference @@ -463,7 +464,10 @@ end # https://github.com/TuringLang/Turing.jl/issues/1725 # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); chn = sample( - model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), num_samples + StableRNG(23), + model, + Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), + num_samples, ) # The number of m variables that have a non-zero value in a sample. num_ms = count(ismissing.(Array(chn[:, (num_zs + 1):end, 1])); dims=2) @@ -476,7 +480,7 @@ end # the posterior is analytically known? Doing 10_000 samples to run the test suite # is not ideal # Issue ref: https://github.com/TuringLang/Turing.jl/issues/2402 - @test isapprox(mean(num_ms), 8.6087; atol=0.5) + @test isapprox(mean(num_ms), 8.6087; atol=0.8) @test isapprox(std(num_ms), 1.8865; atol=0.02) end From 96f8dd4cb827d8a6dbdf642f65193e70812ed767 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 18 Dec 2024 20:19:08 +0000 Subject: [PATCH 70/70] Fix a kwarg --- test/dynamicppl/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamicppl/compiler.jl b/test/dynamicppl/compiler.jl index 359beba620..7939c7beb1 100644 --- a/test/dynamicppl/compiler.jl +++ b/test/dynamicppl/compiler.jl @@ -177,7 +177,7 @@ const gdemo_default = gdemo_d() end @testset "sample" begin - alg = Gibbs(; m=HMC(0.2, 3), PG(10)) + alg = Gibbs(; m=HMC(0.2, 3), s=PG(10)) chn = sample(gdemo_default, alg, 1000) end