diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 2772de28bf..768c43f990 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,11 +1,3 @@ style="blue" format_markdown = true import_to_using = false -# TODO -# We ignore these files because when formatting was first put in place they were being worked on. -# These ignores should be removed once the relevant PRs are merged/closed. -ignore = [ - # https://github.com/TuringLang/Turing.jl/pull/2328/files - "src/experimental/gibbs.jl", - "test/experimental/gibbs.jl", -] diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index a1fe49b463..f648f8b103 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -34,12 +34,10 @@ jobs: args: "mcmc/abstractmcmc.jl" - name: "mcmc/Inference" args: "mcmc/Inference.jl" - - name: "experimental/gibbs" - args: "experimental/gibbs.jl" - name: "mcmc/ess" args: "mcmc/ess.jl" - name: "everything else" - args: "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl" + args: "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl" runner: # Default - version: '1' diff --git a/HISTORY.md b/HISTORY.md index 3bc362d2b4..ff50fb7795 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +# Release 0.36.0 + +## Breaking changes + +0.36.0 introduces a new Gibbs sampler. It has been included in several previous releases as `Turing.Experimental.Gibbs`, and it now replaces the old Gibbs sampler, which has been removed completely. + +The new Gibbs sampler supports the same user-facing interface as the old one. However, because +its internals have been completely rewritten, there may be accidental breakage that we have not +anticipated. Please report any issues you find. + +`GibbsConditional` has also been removed. It was never a prominent part of the user-facing API, but it was exported, so technically its removal is breaking. + +The old Gibbs constructor was called with several subsamplers, and each subsampler's constructor took as arguments the symbols for the variables it was to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarName`s, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable. + +Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))`, has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced in this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`.
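As a quick illustration of the constructor change described above, here is a minimal sketch (the model `demo` and the variables `x` and `y` are hypothetical names invented for this example and are not part of this diff; the `Gibbs` and `RepeatSampler` calls mirror the snippets in the release notes):

```julia
using Turing

# A hypothetical model with one continuous and one discrete variable.
@model function demo()
    x ~ Normal(0, 1)                # continuous: sampled with HMC below
    y ~ Categorical(fill(0.25, 4))  # discrete: sampled with MH below
end

# Old, now-deprecated style: Gibbs(HMC(0.01, 4, :x), MH(:y))
# New keyword syntax:
chain = sample(demo(), Gibbs(; x=HMC(0.01, 4), y=MH()), 1_000)

# Equivalent Pair/VarName syntax; RepeatSampler runs the HMC component
# twice for every MH update of `y`.
chain = sample(
    demo(),
    Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH()),
    1_000,
)
```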
+ # Release 0.35.0 ## Breaking changes diff --git a/Project.toml b/Project.toml index 53ca7d35dd..eea0b0b7ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.35.5" +version = "0.36.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -49,7 +49,7 @@ TuringOptimExt = "Optim" [compat] ADTypes = "1.9" -AbstractMCMC = "5.2" +AbstractMCMC = "5.5" Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" AdvancedMH = "0.8" diff --git a/src/Turing.jl b/src/Turing.jl index 08534aa4b2..6318e2bd52 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -55,7 +55,6 @@ using .Variational include("optimisation/Optimisation.jl") using .Optimisation -include("experimental/Experimental.jl") include("deprecated.jl") # to be removed in the next minor version release ########### @@ -88,7 +87,6 @@ export @model, # modelling Emcee, ESS, Gibbs, - GibbsConditional, HMC, # Hamiltonian-like sampling SGLD, SGHMC, @@ -99,6 +97,7 @@ export @model, # modelling SMC, CSMC, PG, + RepeatSampler, vi, # variational inference ADVI, sample, # inference diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl deleted file mode 100644 index 518538e6c3..0000000000 --- a/src/experimental/Experimental.jl +++ /dev/null @@ -1,16 +0,0 @@ -module Experimental - -using Random: Random -using AbstractMCMC: AbstractMCMC -using DynamicPPL: DynamicPPL, VarName -using Accessors: Accessors - -using DocStringExtensions: TYPEDFIELDS -using Distributions - -using ..Turing: Turing -using ..Turing.Inference: gibbs_rerun, InferenceAlgorithm - -include("gibbs.jl") - -end diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl deleted file mode 100644 index 596e6e283b..0000000000 --- a/src/experimental/gibbs.jl +++ /dev/null @@ -1,488 +0,0 @@ -# Basically like a `DynamicPPL.FixedContext` but -# 1. Hijacks the tilde pipeline to fix variables. -# 2. Computes the log-probability of the fixed variables. -# -# Purpose: avoid triggering resampling of variables we're conditioning on. -# - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. -# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to -# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable -# rather than only for the "true" observations. -# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline -# rather than the `observe` pipeline for the conditioned variables. 
-struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values - context::Ctx -end - -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) - -DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::GibbsContext) = context.context -DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) - -# has and get -has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) -function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) -end - -get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) -function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(get_conditioned_gibbs, context), vns) -end - -# Tilde pipeline -function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) -end - -function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) -end - -# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. -make_broadcastable(x) = x -make_broadcastable(dist::Distribution) = tuple(dist) - -# Need the following two methods to properly support broadcasting over columns. -broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) -function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) - return loglikelihood(dist, x) -end - -# Needed to support broadcasting over columns for `MultivariateDistribution`s. -reconstruct_getvalue(dist, x) = x -function reconstruct_getvalue( - dist::MultivariateDistribution, - x::AbstractVector{<:AbstractVector{<:Real}} -) - return reduce(hcat, x[2:end]; init=x[1]) -end - -function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi -) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. 
- return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) -end - - -""" - preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) - -Returns the preferred value type for a variable with the given `varinfo`. -""" -preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict -preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple -function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) - # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. - namedtuple_compatible = all(varinfo.metadata) do md - eltype(md.vns) <: VarName{<:Any,typeof(identity)} - end - return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict -end - -""" - condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) - -Return a `GibbsContext` with the given values treated as conditioned. - -# Arguments -- `context::DynamicPPL.AbstractContext`: The context to condition. -- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. - If multiple values are provided, we recursively condition on each of them. -""" -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. -function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs( - condition_gibbs(context, value), - values... - ) -end - -# For `DynamicPPL.AbstractVarInfo` we just extract the values. -""" - condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) - -Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. -""" -function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) - return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) -end -function condition_gibbs( - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo, - varinfos::DynamicPPL.AbstractVarInfo... -) - return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) -end -# Allow calling this on a `DynamicPPL.Model` directly. -function condition_gibbs(model::DynamicPPL.Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end - - -""" - make_conditional_model(model, varinfo, varinfos) - -Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. - -# Examples -```julia-repl -julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); - -julia> # A separate varinfo for each variable in `model`. - varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); - -julia> # The varinfo we want to NOT condition on. - target_varinfo = first(varinfos); - -julia> # Results in a model with only `m` conditioned. 
- conditioned_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos); - -julia> result = conditioned_model(); - -julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` -true - -julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` -true -``` -""" -function make_conditional(model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos) - # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return condition_gibbs( - model, - filter(Base.Fix1(!==, target_varinfo), varinfos)... - ) -end -# Assumes the ones given are the ones to condition on. -function make_conditional(model::DynamicPPL.Model, varinfos) - return condition_gibbs( - model, - varinfos... - ) -end - -# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` -# or an `AbstractInferenceAlgorithm`. -wrap_algorithm_maybe(x) = x -wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) - -""" - Gibbs - -A type representing a Gibbs sampler. - -# Fields -$(TYPEDFIELDS) -""" -struct Gibbs{V,A} <: InferenceAlgorithm - "varnames representing variables for each sampler" - varnames::V - "samplers for each entry in `varnames`" - samplers::A -end - -# NamedTuple -Gibbs(; algs...) = Gibbs(NamedTuple(algs)) -function Gibbs(algs::NamedTuple) - return Gibbs( - map(s -> VarName{s}(), keys(algs)), - map(wrap_algorithm_maybe, values(algs)), - ) -end - -# AbstractDict -function Gibbs(algs::AbstractDict) - return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) -end -function Gibbs(algs::Pair...) - return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) -end - -# TODO: Remove when no longer needed. -DynamicPPL.getspace(::Gibbs) = () - -struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} - vi::V - states::S -end - -_maybevec(x) = vec(x) # assume it's iterable -_maybevec(x::Tuple) = [x...] -_maybevec(x::VarName) = [x] - -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - vi_base::DynamicPPL.AbstractVarInfo; - initial_params=nothing, - kwargs..., -) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers - - # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(model) - - # Simple way of setting the initial parameters: set them in the `vi_base` - # if they are given so they propagate to the subset varinfos used by each sampler. - if initial_params !== nothing - vi_base = DynamicPPL.unflatten(vi_base, initial_params) - end - - # Create the varinfos for each sampler. - varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) - initial_params_all = if initial_params === nothing - fill(nothing, length(varnames)) - else - # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], varinfos) - end - - # 2. Construct a varinfo for every vn + sampler combo. - states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local - # Construct the conditional model. - model_local = make_conditional(model, varinfo_local, varinfos) - - # Take initial step. - new_state_local = last(AbstractMCMC.step( - rng, model_local, sampler_local; - # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. - # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. 
- initial_params=initial_params_local, - kwargs... - )) - - # Return the new state and the invlinked `varinfo`. - vi_local_state = Turing.Inference.varinfo(new_state_local) - vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) - DynamicPPL.invlink(vi_local_state, sampler_local, model_local) - else - vi_local_state - end - return (new_state_local, vi_local_state_linked) - end - - states = map(first, states_and_varinfos) - varinfos = map(last, states_and_varinfos) - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - state::GibbsState; - kwargs..., -) - alg = spl.alg - samplers = alg.samplers - states = state.states - varinfos = map(Turing.Inference.varinfo, state.states) - @assert length(samplers) == length(state.states) - - # TODO: move this into a recursive function so we can unroll when reasonable? - for index = 1:length(samplers) - # Take the inner step. - new_state_local, new_varinfo_local = gibbs_step_inner( - rng, - model, - samplers, - states, - varinfos, - index; - kwargs..., - ) - - # Update the `states` and `varinfos`. - states = Accessors.setindex(states, new_state_local, index) - varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) - end - - # Combine the resulting varinfo objects. - # The last varinfo holds the correctly computed logp. - vi_base = state.vi - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!( - varinfos, - merge(vi_base, first(varinfos)), - firstindex(varinfos), - ) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. -function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) - # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide - # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact - # same `selector` as before but now with `rerun` set to `true` if needed. - return Accessors.@set sampler.selector.rerun = true -end - -# Interface we need a sampler to implement to work as a component in a Gibbs sampler. -""" - gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - -Check if the log-probability of the destination model needs to be recomputed. - -Defaults to `true` -""" -function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - return true -end - -# TODO: Remove `rng`? -function Turing.Inference.recompute_logprob!!( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler, - state -) - varinfo = Turing.Inference.varinfo(state) - # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, - # e.g. 
log-likelihood in the scenario of `ESS`. - # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. - sampler_rerun = make_rerun_sampler(model, sampler) - # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed - # `varinfo`, even if `varinfo` was linked. - varinfo_new = last(DynamicPPL.evaluate!!( - model, - varinfo, - # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. - DynamicPPL.SamplingContext(rng, sampler_rerun) - )) - # Update the state we're about to use if need be. - # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - return Turing.Inference.gibbs_state(model, sampler, state, varinfo_new) -end - -function gibbs_step_inner( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - samplers, - states, - varinfos, - index; - kwargs..., -) - # Needs to do a a few things. - sampler_local = samplers[index] - state_local = states[index] - varinfo_local = varinfos[index] - - # Make sure that all `varinfos` are linked. - varinfos_invlinked = map(varinfos) do vi - # NOTE: This is immutable linking! - # TODO: Do we need the `istrans` check here or should we just always use `invlink`? - # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 - DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - end - varinfo_local_invlinked = varinfos_invlinked[index] - - # 1. Create conditional model. - # Construct the conditional model. - # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, - # otherwise we're conditioning on values which are not in the support of the - # distributions. - model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) - - # Extract the previous sampler and state. - sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] - state_previous = states[index == 1 ? length(states) : index - 1] - - # 1. Re-run the sampler if needed. - if gibbs_requires_recompute_logprob( - model_local, - sampler_local, - sampler_previous, - state_local, - state_previous - ) - state_local = Turing.Inference.recompute_logprob!!( - rng, - model_local, - sampler_local, - state_local, - ) - end - - # 2. Take step with local sampler. - new_state_local = last( - AbstractMCMC.step( - rng, - model_local, - sampler_local, - state_local; - kwargs..., - ), - ) - - # 3. Extract the new varinfo. - # Return the resulting state and invlinked `varinfo`. 
- varinfo_local_state = Turing.Inference.varinfo(new_state_local) - varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) - DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) - else - varinfo_local_state - end - - # TODO: alternatively, we can return `states_new, varinfos_new, index_new` - return (new_state_local, varinfo_local_state_invlinked) -end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 2716f18a19..5905b1686e 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -33,7 +33,7 @@ using StatsFuns: logsumexp using Random: AbstractRNG using DynamicPPL using AbstractMCMC: AbstractModel, AbstractSampler -using DocStringExtensions: TYPEDEF, TYPEDFIELDS +using DocStringExtensions: FIELDS, TYPEDEF, TYPEDFIELDS using DataStructures: OrderedSet using Accessors: Accessors @@ -62,7 +62,6 @@ export InferenceAlgorithm, ESS, Emcee, Gibbs, # classic sampling - GibbsConditional, HMC, SGLD, PolynomialStepsize, @@ -73,13 +72,13 @@ export InferenceAlgorithm, SMC, CSMC, PG, + RepeatSampler, Prior, assume, dot_assume, observe, dot_observe, predict, - isgibbscomponent, externalsampler ####################### @@ -92,6 +91,20 @@ abstract type Hamiltonian <: InferenceAlgorithm end abstract type StaticHamiltonian <: Hamiltonian end abstract type AdaptiveHamiltonian <: Hamiltonian end +# TODO(mhauru) Remove the below function once all the space/Selector stuff has been removed. +""" + drop_space(alg::InferenceAlgorithm) + +Return an `InferenceAlgorithm` like `alg`, but with all space information removed. +""" +function drop_space end + +function drop_space(sampler::Sampler) + return Sampler(drop_space(sampler.alg), sampler.selector) +end + +include("repeat_sampler.jl") + """ ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} @@ -133,6 +146,9 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain end end +# External samplers don't have notion of space to begin with. +drop_space(x::ExternalSampler) = x + DynamicPPL.getspace(::ExternalSampler) = () """ @@ -201,6 +217,8 @@ Algorithm for sampling from the prior. """ struct Prior <: InferenceAlgorithm end +drop_space(x::Prior) = x + function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -335,7 +353,7 @@ end function AbstractMCMC.sample( rng::AbstractRNG, model::AbstractModel, - sampler::Sampler{<:InferenceAlgorithm}, + sampler::Union{Sampler{<:InferenceAlgorithm},RepeatSampler}, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; @@ -447,7 +465,7 @@ getlogevidence(transitions, sampler, state) = missing function AbstractMCMC.bundle_samples( ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, + spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, chain_type::Type{MCMCChains.Chains}; save_state=false, @@ -510,7 +528,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, + spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, chain_type::Type{Vector{NamedTuple}}; kwargs..., @@ -560,22 +578,21 @@ end # Concrete algorithm implementations. 
# ####################################### +include("abstractmcmc.jl") include("ess.jl") include("hmc.jl") include("mh.jl") include("is.jl") include("particle_mcmc.jl") -include("gibbs_conditional.jl") include("gibbs.jl") include("sghmc.jl") include("emcee.jl") -include("abstractmcmc.jl") ################ # Typing tools # ################ -for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs, :Emcee) +for alg in (:SMC, :PG, :MH, :IS, :ESS, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index a350d2908f..aec7b153a9 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -27,6 +27,7 @@ function varinfo(state::TuringState) # TODO: Do we need to link here first? return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ) end +varinfo(state::AbstractVarInfo) = state # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: @@ -53,51 +54,6 @@ function setvarinfo( ) end -""" - recompute_logprob!!(rng, model, sampler, state) - -Recompute the log-probability of the `model` based on the given `state` and return the resulting state. -""" -function recompute_logprob!!( - rng::Random.AbstractRNG, # TODO: Do we need the `rng` here? - model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:ExternalSampler}, - state, -) - # Re-using the log-density function from the `state` and updating only the `model` field, - # since the `model` might now contain different conditioning values. - f = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype) - # Recompute the log-probability with the new `model`. - state_inner = recompute_logprob!!( - rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state - ) - return state_to_turing(f, state_inner) -end - -function recompute_logprob!!( - rng::Random.AbstractRNG, - model::AbstractMCMC.LogDensityModel, - sampler::AdvancedHMC.AbstractHMCSampler, - state::AdvancedHMC.HMCState, -) - # Construct hamiltionian. - hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) - # Re-compute the log-probability and gradient. - return Accessors.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, state.transition.z.θ, state.transition.z.r - ) -end - -function recompute_logprob!!( - rng::Random.AbstractRNG, - model::AbstractMCMC.LogDensityModel, - sampler::AdvancedMH.MetropolisHastings, - state::AdvancedMH.Transition, -) - logdensity = model.logdensity - return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params) -end - # TODO: Do we also support `resume`, etc? 
function AbstractMCMC.step( rng::Random.AbstractRNG, diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index ebdfa041d7..816d90578a 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -26,6 +26,8 @@ function Emcee(n_walkers::Int, stretch_length=2.0) return Emcee{(),typeof(ensemble)}(ensemble) end +drop_space(alg::Emcee{space,E}) where {space,E} = Emcee{(),E}(alg.ensemble) + struct EmceeState{V<:AbstractVarInfo,S} vi::V states::S diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 395456ee5b..aa1a9fe380 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -25,6 +25,8 @@ struct ESS{space} <: InferenceAlgorithm end ESS() = ESS{()}() ESS(space::Symbol) = ESS{(space,)}() +drop_space(alg::ESS) = ESS() + # always accept in the first step function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 736845b678..0f2c78ebe8 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -1,251 +1,632 @@ -### -### Gibbs samplers / compositional samplers. -### - """ - isgibbscomponent(alg) + isgibbscomponent(alg::Union{InferenceAlgorithm, AbstractMCMC.AbstractSampler}) + +Return a boolean indicating whether `alg` is a valid component for a Gibbs sampler. -Determine whether algorithm `alg` is allowed as a Gibbs component. +Defaults to `false` if no method has been defined for a particular algorithm type. """ -isgibbscomponent(alg) = false +isgibbscomponent(::InferenceAlgorithm) = false +isgibbscomponent(spl::Sampler) = isgibbscomponent(spl.alg) isgibbscomponent(::ESS) = true -isgibbscomponent(::GibbsConditional) = true -isgibbscomponent(::Hamiltonian) = true +isgibbscomponent(::HMC) = true +isgibbscomponent(::HMCDA) = true +isgibbscomponent(::NUTS) = true isgibbscomponent(::MH) = true isgibbscomponent(::PG) = true -const TGIBBS = Union{InferenceAlgorithm,GibbsConditional} - +isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler) + +isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler) +isgibbscomponent(::AdvancedHMC.HMC) = true +isgibbscomponent(::AdvancedMH.MetropolisHastings) = true + +# Basically like a `DynamicPPL.FixedContext` but +# 1. Hijacks the tilde pipeline to fix variables. +# 2. Computes the log-probability of the fixed variables. +# +# Purpose: avoid triggering resampling of variables we're conditioning on. +# - Using standard `DynamicPPL.condition` results in conditioned variables being treated +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to +# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable +# rather than only for the "true" observations. +# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline +# rather than the `observe` pipeline for the conditioned variables. """ - Gibbs(algs...) + GibbsContext{VNs}(global_varinfo, context) -Compositional MCMC interface. Gibbs sampling combines one or more -sampling algorithms, each of which samples from a different set of -variables in a model. +A context used in the implementation of the Turing.jl Gibbs sampler. -Example: -```julia -@model function gibbs_example(x) - v1 ~ Normal(0,1) - v2 ~ Categorical(5) -end +There will be one `GibbsContext` for each iteration of a component sampler. -# Use PG for a 'v2' variable, and use HMC for the 'v1' variable. 
-# Note that v2 is discrete, so the PG sampler is more appropriate -# than is HMC. -alg = Gibbs(HMC(0.2, 3, :v1), PG(20, :v2)) -``` +`VNs` is a a tuple of symbols for `VarName`s that the current component +sampler is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume` +calls to its child context. For other variables, their values will be fixed to the values +they have in `global_varinfo`. -One can also pass the number of iterations for each Gibbs component using the following syntax: -- `alg = Gibbs((HMC(0.2, 3, :v1), n_hmc), (PG(20, :v2), n_pg))` -where `n_hmc` and `n_pg` are the number of HMC and PG iterations for each Gibbs iteration. +The naive implementation of `GibbsContext` would simply have a field `target_varnames` that +would be a collection of `VarName`s that the current component sampler is sampling. The +reason we instead have a `Tuple` type parameter listing `Symbol`s is to allow +`is_target_varname` to benefit from compile time constant propagation. This is important +for type stability of `tilde_assume`. -Tips: -- `HMC` and `NUTS` are fast samplers and can throw off particle-based -methods like Particle Gibbs. You can increase the effectiveness of particle sampling by including -more particles in the particle sampler. +# Fields +$(FIELDS) """ -struct Gibbs{space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} <: InferenceAlgorithm - algs::A # component sampling algorithms - iterations::B - function Gibbs{space,N,A,B}( - algs::A, iterations::B - ) where {space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} - all(isgibbscomponent, algs) || - error("all algorithms have to support Gibbs sampling") - return new{space,N,A,B}(algs, iterations) +struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: + DynamicPPL.AbstractContext + """ + a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both + those fixed and those being sampled. We use a `Ref` because this field may need to be + updated if new variables are introduced. + """ + global_varinfo::GVI + """ + the child context that tilde calls will eventually be passed onto. + """ + context::Ctx + + function GibbsContext{VNs}(global_varinfo, context) where {VNs} + if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf) + error("GibbsContext can only wrap a leaf context, not a $(context).") + end + return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context) + end + + function GibbsContext(target_varnames, global_varinfo, context) + if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf) + error("GibbsContext can only wrap a leaf context, not a $(context).") + end + if any(vn -> DynamicPPL.getoptic(vn) != identity, target_varnames) + msg = + "All Gibbs target variables must have identity lenses. " * + "For example, you can't have `@varname(x.a[1])` as a target variable, " * + "only `@varname(x)`." + error(msg) + end + vn_sym = tuple(unique((DynamicPPL.getsym(vn) for vn in target_varnames))...) + return new{vn_sym,typeof(global_varinfo),typeof(context)}(global_varinfo, context) end end -function Gibbs(alg1::TGIBBS, algrest::Vararg{TGIBBS,N}) where {N} - algs = (alg1, algrest...) 
- iterations = ntuple(Returns(1), Val(N + 1)) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +function GibbsContext(target_varnames, global_varinfo) + return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext()) end -function Gibbs(arg1::Tuple{<:TGIBBS,Int}, argrest::Vararg{Tuple{<:TGIBBS,Int},N}) where {N} - allargs = (arg1, argrest...) - algs = map(first, allargs) - iterations = map(last, allargs) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::GibbsContext) = context.context +function DynamicPPL.setchildcontext(context::GibbsContext{VNs}, childcontext) where {VNs} + return GibbsContext{VNs}(Ref(context.global_varinfo[]), childcontext) end -""" - GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} +get_global_varinfo(context::GibbsContext) = context.global_varinfo[] -Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that -the `Gibbs` sampler iterates through for each `step!`. -""" -struct GibbsState{V<:VarInfo,S<:Tuple{Vararg{Sampler}},T} - vi::V - samplers::S - states::T +function set_global_varinfo!(context::GibbsContext, new_global_varinfo) + context.global_varinfo[] = new_global_varinfo + return nothing end -# extract varinfo object from state -""" - gibbs_varinfo(model, sampler, state) +# has and get +function has_conditioned_gibbs(context::GibbsContext, vn::VarName) + return DynamicPPL.haskey(get_global_varinfo(context), vn) +end +function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns)) + if (num_conditioned != 0) && (num_conditioned != length(vns)) + error( + "Some but not all of the variables in `vns` have been conditioned on. " * + "Having mixed conditioning like this is not supported in GibbsContext.", + ) + end + return num_conditioned > 0 +end -Return the variables corresponding to the current `state` of the Gibbs component `sampler`. -""" -gibbs_varinfo(model, sampler, state) = varinfo(state) -varinfo(state) = state.vi -varinfo(state::AbstractVarInfo) = state +function get_conditioned_gibbs(context::GibbsContext, vn::VarName) + return get_global_varinfo(context)[vn] +end +function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(get_conditioned_gibbs, context), vns) +end -""" - gibbs_state(model, sampler, state, varinfo) +is_target_varname(::GibbsContext{VNs}, ::VarName{sym}) where {VNs,sym} = sym in VNs -Return an updated state, taking into account the variables sampled by other Gibbs components. +function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) + num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns)) + if (num_target != 0) && (num_target != length(vns)) + error( + "Some but not all of the variables in `vns` are target variables. " * + "Having mixed targeting like this is not supported in GibbsContext.", + ) + end + return num_target > 0 +end -# Arguments -- `model`: model targeted by the Gibbs sampler. -- `sampler`: the sampler for this Gibbs component. -- `state`: the state of `sampler` computed in the previous iteration. 
-- `varinfo`: the variables, including the ones sampled by other Gibbs components. -""" -gibbs_state(model, sampler, state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo -function gibbs_state(model, sampler, state::PGState, varinfo::AbstractVarInfo) - return PGState(varinfo, state.rng) +# Tilde pipeline +function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) + child_context = DynamicPPL.childcontext(context) + return if is_target_varname(context, vn) + # Fall back to the default behavior. + DynamicPPL.tilde_assume(child_context, right, vn, vi) + elseif has_conditioned_gibbs(context, vn) + # Short-circuit the tilde assume if `vn` is present in `context`. + value, lp, _ = DynamicPPL.tilde_assume( + child_context, right, vn, get_global_varinfo(context) + ) + value, lp, vi + else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. + value, lp, new_global_vi = DynamicPPL.tilde_assume( + child_context, + DynamicPPL.SampleFromPrior(), + right, + vn, + get_global_varinfo(context), + ) + set_global_varinfo!(context, new_global_vi) + value, lp, vi + end end -# Update state in Gibbs sampling -function gibbs_state( - model::Model, spl::Sampler{<:Hamiltonian}, state::HMCState, varinfo::AbstractVarInfo +# As above but with an RNG. +function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi ) - # Update hamiltonian - θ_old = varinfo[spl] - hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_old)) + # See comment in the above, rng-less version of this method for an explanation. + child_context = DynamicPPL.childcontext(context) + return if is_target_varname(context, vn) + DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) + elseif has_conditioned_gibbs(context, vn) + value, lp, _ = DynamicPPL.tilde_assume( + child_context, right, vn, get_global_varinfo(context) + ) + value, lp, vi + else + value, lp, new_global_vi = DynamicPPL.tilde_assume( + rng, + child_context, + DynamicPPL.SampleFromPrior(), + right, + vn, + get_global_varinfo(context), + ) + set_global_varinfo!(context, new_global_vi) + value, lp, vi + end +end - # TODO: Avoid mutation - resize!(state.z.θ, length(θ_old)) - state.z.θ .= θ_old - z = state.z +# Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf. +# See comments there for more details. +function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) + child_context = DynamicPPL.childcontext(context) + return if is_target_varname(context, vns) + DynamicPPL.dot_tilde_assume(child_context, right, left, vns, vi) + elseif has_conditioned_gibbs(context, vns) + value, lp, _ = DynamicPPL.dot_tilde_assume( + child_context, right, left, vns, get_global_varinfo(context) + ) + value, lp, vi + else + value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( + child_context, + DynamicPPL.SampleFromPrior(), + right, + left, + vns, + get_global_varinfo(context), + ) + set_global_varinfo!(context, new_global_vi) + value, lp, vi + end +end - return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) +# As above but with an RNG. 
+function DynamicPPL.dot_tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi +) + child_context = DynamicPPL.childcontext(context) + return if is_target_varname(context, vns) + DynamicPPL.dot_tilde_assume(rng, child_context, sampler, right, left, vns, vi) + elseif has_conditioned_gibbs(context, vns) + value, lp, _ = DynamicPPL.dot_tilde_assume( + child_context, right, left, vns, get_global_varinfo(context) + ) + value, lp, vi + else + value, lp, new_global_vi = DynamicPPL.dot_tilde_assume( + rng, + child_context, + DynamicPPL.SampleFromPrior(), + right, + left, + vns, + get_global_varinfo(context), + ) + set_global_varinfo!(context, new_global_vi) + value, lp, vi + end end """ - gibbs_rerun(prev_alg, alg) + make_conditional(model, target_variables, varinfo) -Check if the model should be rerun to recompute the log density before sampling with the -Gibbs component `alg` and after sampling from Gibbs component `prev_alg`. +Return a new, conditioned model for a component of a Gibbs sampler. -By default, the function returns `true`. +# Arguments +- `model::DynamicPPL.Model`: The model to condition. +- `target_variables::AbstractVector{<:VarName}`: The target variables of the component +sampler. These will _not_ be conditioned. +- `varinfo::DynamicPPL.AbstractVarInfo`: Values for all variables in the model. All the +values in `varinfo` but not in `target_variables` will be conditioned to the values they +have in `varinfo`. + +# Returns +- A new model with the variables _not_ in `target_variables` conditioned. +- The `GibbsContext` object that will be used to condition the variables. This is necessary +because evaluation can mutate its `global_varinfo` field, which we need to access later. """ -gibbs_rerun(prev_alg, alg) = true - -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or one of the standard `Hamiltonian` algorithms -gibbs_rerun(::GibbsConditional, ::MH) = false -gibbs_rerun(::Hamiltonian, ::MH) = false +function make_conditional( + model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo +) + # Insert the `GibbsContext` just before the leaf. + # 1. Extract the `leafcontext` from `model` and wrap in `GibbsContext`. + gibbs_context_inner = GibbsContext( + target_variables, Ref(varinfo), DynamicPPL.leafcontext(model.context) + ) + # 2. Set the leaf context to be the `GibbsContext` wrapping `leafcontext(model.context)`. + gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner) + return DynamicPPL.contextualize(model, gibbs_context), gibbs_context_inner +end -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or a `MH` algorithm -gibbs_rerun(::MH, ::Hamiltonian) = false -gibbs_rerun(::GibbsConditional, ::Hamiltonian) = false +# All samplers are given the same Selector, so that they will sample all variables +# given to them by the Gibbs sampler. This avoids conflicts between the new and the old way +# of choosing which sampler to use. 
+function set_selector(x::DynamicPPL.Sampler) + return DynamicPPL.Sampler(x.alg, DynamicPPL.Selector(0)) +end +function set_selector(x::RepeatSampler) + return RepeatSampler(set_selector(x.sampler), x.num_repeat) +end +set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) -# do not have to recompute `vi.logp` since it is not used in `step` -gibbs_rerun(prev_alg, ::GibbsConditional) = false +""" + Gibbs -# Do not recompute `vi.logp` since it is reset anyway in `step` -gibbs_rerun(prev_alg, ::PG) = false +A type representing a Gibbs sampler. -# Initialize the Gibbs sampler. -function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, vi::AbstractVarInfo; kwargs... -) - # TODO: Technically this only works for `VarInfo` or `ThreadSafeVarInfo{<:VarInfo}`. - # Should we enforce this? - - # Create tuple of samplers - algs = spl.alg.algs - i = 0 - samplers = map(algs) do alg - i += 1 - if i == 1 - prev_alg = algs[end] - else - prev_alg = algs[i - 1] +# Fields +$(TYPEDFIELDS) +""" +struct Gibbs{V,A} <: InferenceAlgorithm + "varnames representing variables for each sampler" + varnames::V + "samplers for each entry in `varnames`" + samplers::A + + function Gibbs(varnames, samplers) + if length(varnames) != length(samplers) + throw(ArgumentError("Number of varnames and samplers must match.")) + end + for spl in samplers + if !isgibbscomponent(spl) + msg = "All samplers must be valid Gibbs components, $(spl) is not." + throw(ArgumentError(msg)) + end end - rerun = gibbs_rerun(prev_alg, alg) - selector = DynamicPPL.Selector(Symbol(typeof(alg)), rerun) - Sampler(alg, model, selector) + return new{typeof(varnames),typeof(samplers)}(varnames, samplers) end +end - # Add Gibbs to gids for all variables. - for sym in keys(vi.metadata) - vns = getfield(vi.metadata, sym).vns +to_varname(vn::VarName) = vn +to_varname(s::Symbol) = VarName{s}() +# Any other value is assumed to be an iterable. +to_varname(t) = map(to_varname, collect(t)) - for vn in vns - # update the gid for the Gibbs sampler - DynamicPPL.updategid!(vi, vn, spl) +# NamedTuple +Gibbs(; algs...) = Gibbs(NamedTuple(algs)) +function Gibbs(algs::NamedTuple) + return Gibbs(map(to_varname, keys(algs)), map(set_selector ∘ drop_space, values(algs))) +end - # try to store each subsampler's gid in the VarInfo - for local_spl in samplers - DynamicPPL.updategid!(vi, vn, local_spl) - end +# AbstractDict +function Gibbs(algs::AbstractDict) + return Gibbs( + map(to_varname, collect(keys(algs))), map(set_selector ∘ drop_space, values(algs)) + ) +end +function Gibbs(algs::Pair...) + return Gibbs(map(to_varname ∘ first, algs), map(set_selector ∘ drop_space ∘ last, algs)) +end + +# The below two constructors only provide backwards compatibility with the constructor of +# the old Gibbs sampler. They are deprecated and will be removed in the future. +function Gibbs(algs::InferenceAlgorithm...) + varnames = map(algs) do alg + space = getspace(alg) + if (space isa VarName) + space + elseif (space isa Symbol) + VarName{space}() + else + tuple((s isa Symbol ? VarName{s}() : s for s in space)...) end end + msg = ( + "Specifying which sampler to use with which variable using syntax like " * + "`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " * + "Please use `Gibbs(; x=NUTS(), y=MH())` instead. If you want different iteration " * + "counts for different subsamplers, use e.g. 
" * + "`Gibbs(@varname(x) => RepeatSampler(NUTS(), 2), @varname(y) => MH())`" + ) + Base.depwarn(msg, :Gibbs) + return Gibbs(varnames, map(set_selector ∘ drop_space, algs)) +end - # Compute initial states of the local samplers. - states = map(samplers) do local_spl - # Recompute `vi.logp` if needed. - if local_spl.selector.rerun - vi = last( - DynamicPPL.evaluate!!( - model, vi, DynamicPPL.SamplingContext(rng, local_spl) - ), - ) - end +function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) + algs = Iterators.map(first, algs_with_iters) + iters = Iterators.map(last, algs_with_iters) + algs_duplicated = Iterators.flatten(( + Iterators.repeated(alg, iter) for (alg, iter) in zip(algs, iters) + )) + # This calls the other deprecated constructor from above, hence no need for a depwarn + # here. + return Gibbs(algs_duplicated...) +end + +# TODO: Remove when no longer needed. +DynamicPPL.getspace(::Gibbs) = () + +struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} + vi::V + states::S +end - # Compute initial state. - _, state = DynamicPPL.initialstep(rng, model, local_spl, vi; kwargs...) +_maybevec(x) = vec(x) # assume it's iterable +_maybevec(x::Tuple) = [x...] +_maybevec(x::VarName) = [x] +_maybevec(x::Symbol) = [x] - # Update `VarInfo` object. - vi = gibbs_varinfo(model, local_spl, state) +varinfo(state::GibbsState) = state.vi - return state +function DynamicPPL.initialstep( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + vi_base::DynamicPPL.AbstractVarInfo; + initial_params=nothing, + kwargs..., +) + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers + + # Run the model once to get the varnames present + initial values to condition on. + vi = DynamicPPL.VarInfo(rng, model) + if initial_params !== nothing + vi = DynamicPPL.unflatten(vi, initial_params) end - # Compute initial transition and state. - transition = Transition(model, vi) - state = GibbsState(vi, samplers, states) + # Initialise each component sampler in turn, collect all their states. + states = [] + for (varnames_local, sampler_local) in zip(varnames, samplers) + varnames_local = _maybevec(varnames_local) + # Get the initial values for this component sampler. + initial_params_local = if initial_params === nothing + nothing + else + DynamicPPL.subset(vi, varnames_local)[:] + end - return transition, state + # Construct the conditioned model. + model_local, context_local = make_conditional(model, varnames_local, vi) + + # Take initial step. + _, new_state_local = AbstractMCMC.step( + rng, + model_local, + sampler_local; + # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. + # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. + initial_params=initial_params_local, + kwargs..., + ) + new_vi_local = varinfo(new_state_local) + # Merge in any new variables that were introduced during the step, but that + # were not in the domain of the current sampler. + vi = merge(vi, get_global_varinfo(context_local)) + # Merge the new values for all the variables sampled by the current sampler. + vi = merge(vi, new_vi_local) + push!(states, new_state_local) + end + return Transition(model, vi), GibbsState(vi, states) end -# Subsequent steps function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, state::GibbsState; kwargs... 
+ rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + state::GibbsState; + kwargs..., ) - # Iterate through each of the samplers. - vi = state.vi - samplers = state.samplers - states = map(samplers, spl.alg.iterations, state.states) do _sampler, iteration, _state - # Recompute `vi.logp` if needed. - if _sampler.selector.rerun - vi = last(DynamicPPL.evaluate!!(model, rng, vi, _sampler)) - end + vi = varinfo(state) + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers + states = state.states + @assert length(samplers) == length(state.states) + + # TODO: move this into a recursive function so we can unroll when reasonable? + for index in 1:length(samplers) + # Take the inner step. + sampler_local = samplers[index] + state_local = states[index] + varnames_local = _maybevec(varnames[index]) + vi, new_state_local = gibbs_step_inner( + rng, model, varnames_local, sampler_local, state_local, vi; kwargs... + ) + states = Accessors.setindex(states, new_state_local, index) + end + return Transition(model, vi), GibbsState(vi, states) +end - # Update state of current sampler with updated `VarInfo` object. - current_state = gibbs_state(model, _sampler, _state, vi) +""" + setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo) - # Step through the local sampler. - newstate = current_state - for _ in 1:iteration - _, newstate = AbstractMCMC.step(rng, model, _sampler, newstate; kwargs...) - end +A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an +`AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to +`AbstractMCMC.setparams!!(model, state, params[:])`. - # Update `VarInfo` object. - vi = gibbs_varinfo(model, _sampler, newstate) +`model` is typically a `DynamicPPL.Model`, but can also be e.g. an +`AbstractMCMC.LogDensityModel`. +""" +function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) + return AbstractMCMC.setparams!!(model, state, params[:]) +end - return newstate +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::Sampler{<:MH}, + state::AbstractVarInfo, + params::AbstractVarInfo, +) + # The state is already a VarInfo, so we can just return `params`, but first we need to + # update its logprob. + # NOTE: Using `leafcontext(model.context)` here is a no-op, as it will be concatenated + # with `model.context` before hitting `model.f`. + return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) +end + +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::Sampler{<:ESS}, + state::AbstractVarInfo, + params::AbstractVarInfo, +) + # The state is already a VarInfo, so we can just return `params`, but first we need to + # update its logprob. To do this, we have to call evaluate!! with the sampler, rather + # than just a context, because ESS is peculiar in how it uses LikelihoodContext for + # some variables and DefaultContext for others. 
+ return last(DynamicPPL.evaluate!!(model, params, SamplingContext(sampler))) +end + +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::Sampler{<:ExternalSampler}, + state::TuringState, + params::AbstractVarInfo, +) + logdensity = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype) + new_inner_state = setparams_varinfo!!( + AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params + ) + return TuringState(new_inner_state, logdensity) +end + +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::Sampler{<:Hamiltonian}, + state::HMCState, + params::AbstractVarInfo, +) + θ_new = params[:] + hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new)) + + # Update the parameter values in `state.z`. + # TODO: Avoid mutation + z = state.z + resize!(z.θ, length(θ_new)) + z.θ .= θ_new + return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor) +end + +function setparams_varinfo!!( + model::DynamicPPL.Model, sampler::Sampler{<:PG}, state::PGState, params::AbstractVarInfo +) + return PGState(params, state.rng) +end + +""" + match_linking!!(varinfo_local, prev_state_local, model) + +Make sure the linked/invlinked status of varinfo_local matches that of the previous +state for this sampler. This is relevant when multilple samplers are sampling the same +variables, and one might need it to be linked while the other doesn't. +""" +function match_linking!!(varinfo_local, prev_state_local, model) + prev_varinfo_local = varinfo(prev_state_local) + was_linked = DynamicPPL.istrans(prev_varinfo_local) + is_linked = DynamicPPL.istrans(varinfo_local) + if was_linked && !is_linked + varinfo_local = DynamicPPL.link!!(varinfo_local, model) + elseif !was_linked && is_linked + varinfo_local = DynamicPPL.invlink!!(varinfo_local, model) end + # TODO(mhauru) The above might run into trouble if some variables are linked and others + # are not. `istrans(varinfo)` returns an `all` over the individual variables. This could + # especially be a problem with dynamic models, where new variables may get introduced, + # but also in cases where component samplers have partial overlap in their target + # variables. The below is how I would like to implement this, but DynamicPPL at this + # time does not support linking individual variables selected by `VarName`. It soon + # should though, so come back to this. + # Issue ref: https://github.com/TuringLang/Turing.jl/issues/2401 + # prev_links_dict = Dict(vn => DynamicPPL.istrans(prev_varinfo_local, vn) for vn in keys(prev_varinfo_local)) + # any_linked = any(values(prev_links_dict)) + # for vn in keys(varinfo_local) + # was_linked = if haskey(prev_varinfo_local, vn) + # prev_links_dict[vn] + # else + # # If the old state didn't have this variable, we assume it was linked if _any_ + # # of the variables of the old state were linked. + # any_linked + # end + # is_linked = DynamicPPL.istrans(varinfo_local, vn) + # if was_linked && !is_linked + # varinfo_local = DynamicPPL.invlink!!(varinfo_local, vn) + # elseif !was_linked && is_linked + # varinfo_local = DynamicPPL.link!!(varinfo_local, vn) + # end + # end + return varinfo_local +end - return Transition(model, vi), GibbsState(vi, samplers, states) +function gibbs_step_inner( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + varnames_local, + sampler_local, + state_local, + global_vi; + kwargs..., +) + # Construct the conditional model and the varinfo that this sampler should use. 
+ model_local, context_local = make_conditional(model, varnames_local, global_vi)
+ varinfo_local = subset(global_vi, varnames_local)
+ varinfo_local = match_linking!!(varinfo_local, state_local, model)
+
+ # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
+ # sampled by other samplers, we don't need to `setparams`, but could rather simply
+ # recompute the log probability. Moreover, in some cases the recomputation could also
+ # be avoided, if e.g. the previous sampler has done all the necessary work already.
+ # However, we've judged that doing any caching or other tricks to avoid this now would
+ # be premature optimization. In most use cases of Gibbs a single model call here is not
+ # going to be a significant expense anyway.
+ # Set the state of the current sampler, accounting for any changes made by other
+ # samplers.
+ state_local = setparams_varinfo!!(
+ model_local, sampler_local, state_local, varinfo_local
+ )
+
+ # Take a step with the local sampler.
+ new_state_local = last(
+ AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
+ )
+
+ new_vi_local = varinfo(new_state_local)
+ # Merge the latest values for all the variables in the current sampler.
+ new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
+ new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
+ return new_global_vi, new_state_local
 end
diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl
deleted file mode 100644
index fda79315b2..0000000000
--- a/src/mcmc/gibbs_conditional.jl
+++ /dev/null
@@ -1,88 +0,0 @@
-"""
- GibbsConditional(sym, conditional)
-
-A "pseudo-sampler" to manually provide analytical Gibbs conditionals to `Gibbs`.
-`GibbsConditional(:x, cond)` will sample the variable `x` according to the conditional `cond`, which
-must therefore be a function from a `NamedTuple` of the conditioned variables to a `Distribution`.
-
-
-The `NamedTuple` that is passed in contains all random variables from the model in an unspecified
-order, taken from the [`VarInfo`](@ref) object over which the model is run. Scalars and vectors are
-stored in their respective shapes. The tuple also contains the value of the conditioned variable
-itself, which can be useful, but using it creates something that is not a Gibbs sampler anymore (see
-[here](https://github.com/TuringLang/Turing.jl/pull/1275#discussion_r434240387)).
-
-# Examples
-
-```julia
-α_0 = 2.0
-θ_0 = inv(3.0)
-x = [1.5, 2.0]
-N = length(x)
-
-@model function inverse_gdemo(x)
- λ ~ Gamma(α_0, θ_0)
- σ = sqrt(1 / λ)
- m ~ Normal(0, σ)
- @.
x ~ \$(Normal(m, σ)) -end - -# The conditionals can be formulated in terms of the following statistics: -x_bar = mean(x) # sample mean -s2 = var(x; mean=x_bar, corrected=false) # sample variance -m_n = N * x_bar / (N + 1) - -function cond_m(c) - λ_n = c.λ * (N + 1) - σ_n = sqrt(1 / λ_n) - return Normal(m_n, σ_n) -end - -function cond_λ(c) - α_n = α_0 + (N - 1) / 2 + 1 - β_n = s2 * N / 2 + c.m^2 / 2 + inv(θ_0) - return Gamma(α_n, inv(β_n)) -end - -m = inverse_gdemo(x) - -sample(m, Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)), 10) -``` -""" -struct GibbsConditional{S,C} - conditional::C - - function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{sym,C}(conditional) - end -end - -DynamicPPL.getspace(::GibbsConditional{S}) where {S} = (S,) - -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - return nothing, vi -end - -function AbstractMCMC.step( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - condvals = DynamicPPL.values_as(DynamicPPL.invlink(vi, model), NamedTuple) - conddist = spl.alg.conditional(condvals) - updated = rand(rng, conddist) - # Setindex allows only vectors in this case. - vi = setindex!!(vi, [updated;], spl) - # Update log joint probability. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - - return nothing, vi -end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 5f1caead27..80de196c6f 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -21,6 +21,8 @@ end ### Hamiltonian Monte Carlo samplers. ### +varinfo(state::HMCState) = state.vi + """ HMC(ϵ::Float64, n_leapfrog::Int; adtype::ADTypes.AbstractADType = AutoForwardDiff()) @@ -76,6 +78,10 @@ function HMC( return HMC(ϵ, n_leapfrog, metricT, space; adtype=adtype) end +function drop_space(alg::HMC{AD,space,metricT}) where {AD,space,metricT} + return HMC{AD,(),metricT}(alg.ϵ, alg.n_leapfrog, alg.adtype) +end + DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform() # Handle setting `nadapts` and `discard_initial` @@ -376,6 +382,10 @@ function HMCDA( return HMCDA(n_adapts, δ, λ, init_ϵ, metricT, space; adtype=adtype) end +function drop_space(alg::HMCDA{AD,space,metricT}) where {AD,space,metricT} + return HMCDA{AD,(),metricT}(alg.n_adapts, alg.δ, alg.λ, alg.ϵ, alg.adtype) +end + """ NUTS(n_adapts::Int, δ::Float64; max_depth::Int=10, Δ_max::Float64=1000.0, init_ϵ::Float64=0.0; adtype::ADTypes.AbstractADType=AutoForwardDiff() @@ -453,6 +463,12 @@ function NUTS(; kwargs...) return NUTS(-1, 0.65; kwargs...) end +function drop_space(alg::NUTS{AD,space,metricT}) where {AD,space,metricT} + return NUTS{AD,(),metricT}( + alg.n_adapts, alg.δ, alg.max_depth, alg.Δ_max, alg.ϵ, alg.adtype + ) +end + for alg in (:HMC, :HMCDA, :NUTS) @eval getmetricT(::$alg{<:Any,<:Any,metricT}) where {metricT} = metricT end diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 0fe6e10535..083bc7bc3a 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -28,6 +28,8 @@ struct IS{space} <: InferenceAlgorithm end IS() = IS{()}() +drop_space(alg::IS) = IS() + DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler function DynamicPPL.initialstep( diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index bc2519d71e..edd46a4572 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -20,7 +20,6 @@ Construct a Metropolis-Hastings algorithm. The arguments `space` can be - Blank (i.e. 
`MH()`), in which case `MH` defaults to using the prior for each parameter as the proposal distribution. -- A set of one or more symbols to sample with `MH` in conjunction with `Gibbs`, i.e. `Gibbs(MH(:m), PG(10, :s))` - An iterable of pairs or tuples mapping a `Symbol` to a `AdvancedMH.Proposal`, `Distribution`, or `Function` that generates returns a conditional proposal distribution. - A covariance matrix to use as for mean-zero multivariate normal proposals. @@ -41,15 +40,6 @@ chain = sample(gdemo(1.5, 2.0), MH(), 1_000) mean(chain) ``` -Alternatively, you can specify particular parameters to sample if you want to combine sampling -from multiple samplers: - -```julia -# Samples s² with MH and m with PG -chain = sample(gdemo(1.5, 2.0), Gibbs(MH(:s²), PG(10, :m)), 1_000) -mean(chain) -``` - Specifying a single distribution implies the use of static MH: ```julia @@ -155,6 +145,8 @@ function MH(space...) return MH{tuple(syms...),typeof(proposals)}(proposals) end +drop_space(alg::MH{space,P}) where {space,P} = MH{(),P}(alg.proposals) + # Some of the proposals require working in unconstrained space. transform_maybe(proposal::AMH.Proposal) = proposal function transform_maybe(proposal::AMH.RandomWalkProposal) @@ -260,7 +252,7 @@ function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVa end @generated function _val_tuple(vi::VarInfo, vns::NamedTuple{names}) where {names} - isempty(names) === 0 && return :(NamedTuple()) + isempty(names) && return :(NamedTuple()) expr = Expr(:tuple) expr.args = Any[ :( diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index a2b6757204..c5abb56f1d 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -45,6 +45,8 @@ end SMC(space::Symbol...) = SMC(space) SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space) +drop_space(alg::SMC{space,R}) where {space,R} = SMC{(),R}(alg.resampler) + struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T @@ -220,6 +222,8 @@ function PG(nparticles::Int, space::Tuple) return PG(nparticles, AdvancedPS.ResampleWithESSThreshold(), space) end +drop_space(alg::PG{space,R}) where {space,R} = PG{(),R}(alg.nparticles, alg.resampler) + """ CSMC(...) @@ -241,6 +245,8 @@ struct PGState rng::Random.AbstractRNG end +varinfo(state::PGState) = state.vi + function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) theta = getparams(model, vi) diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl new file mode 100644 index 0000000000..a3e38f46a9 --- /dev/null +++ b/src/mcmc/repeat_sampler.jl @@ -0,0 +1,62 @@ +""" + RepeatSampler <: AbstractMCMC.AbstractSampler + +A `RepeatSampler` is a container for a sampler and a number of times to repeat it. 
+ +# Fields +$(FIELDS) + +# Examples +```julia +repeated_sampler = RepeatSampler(sampler, 10) +AbstractMCMC.step(rng, model, repeated_sampler) # take 10 steps of `sampler` +``` +""" +struct RepeatSampler{S<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSampler + "The sampler to repeat" + sampler::S + "The number of times to repeat the sampler" + num_repeat::Int + + function RepeatSampler(sampler::S, num_repeat::Int) where {S} + @assert num_repeat > 0 + return new{S}(sampler, num_repeat) + end +end + +function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int) + return RepeatSampler(Sampler(alg), num_repeat) +end + +drop_space(rs::RepeatSampler) = RepeatSampler(drop_space(rs.sampler), rs.num_repeat) +getADType(spl::RepeatSampler) = getADType(spl.sampler) +DynamicPPL.default_chain_type(sampler::RepeatSampler) = default_chain_type(sampler.sampler) +DynamicPPL.getspace(spl::RepeatSampler) = getspace(spl.sampler) +DynamicPPL.inspace(vn::VarName, spl::RepeatSampler) = inspace(vn, spl.sampler) + +function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params) + return setparams_varinfo!!(model, sampler.sampler, state, params) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatSampler; + kwargs..., +) + return AbstractMCMC.step(rng, model, sampler.sampler; kwargs...) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatSampler, + state; + kwargs..., +) + transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...) + for _ in 2:(sampler.num_repeat) + transition, state = AbstractMCMC.step(rng, model, sampler.sampler, state; kwargs...) + end + return transition, state +end diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index fbc4b11868..c79337c500 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -49,6 +49,10 @@ function SGHMC( ) end +function drop_space(alg::SGHMC{AD,space,T}) where {AD,space,T} + return SGHMC{AD,(),T}(alg.learning_rate, alg.momentum_decay, alg.adtype) +end + struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} logdensity::L vi::V @@ -130,6 +134,10 @@ struct SGLD{AD,space,S} <: StaticHamiltonian adtype::AD end +function drop_space(alg::SGLD{AD,space,S}) where {AD,space,S} + return SGLD{AD,(),S}(alg.stepsize, alg.adtype) +end + struct PolynomialStepsize{T<:Real} "Constant scale factor of the step size." 
a::T diff --git a/test/Project.toml b/test/Project.toml index 9253f8efa2..b6f8ba75a5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" diff --git a/test/dynamicppl/compiler.jl b/test/dynamicppl/compiler.jl index e063d8f9c6..7939c7beb1 100644 --- a/test/dynamicppl/compiler.jl +++ b/test/dynamicppl/compiler.jl @@ -46,7 +46,7 @@ const gdemo_default = gdemo_d() @model function testbb(obs) p ~ Beta(2, 2) x ~ Bernoulli(p) - for i in 1:length(obs) + for i in eachindex(obs) obs[i] ~ Bernoulli(p) end return p, x @@ -54,7 +54,7 @@ const gdemo_default = gdemo_d() smc = SMC() pg = PG(10) - gibbs = Gibbs(HMC(0.2, 3, :p), PG(10, :x)) + gibbs = Gibbs(; p=HMC(0.2, 3), x=PG(10)) chn_s = sample(testbb(obs), smc, 1000) chn_p = sample(testbb(obs), pg, 2000) @@ -73,7 +73,7 @@ const gdemo_default = gdemo_d() m ~ Normal(0, sqrt(s)) # xx ~ Normal(m, sqrt(s)) # this is illegal - for i in 1:length(xs) + for i in eachindex(xs) xs[i] ~ Normal(m, sqrt(s)) # for xx in xs # xx ~ Normal(m, sqrt(s)) @@ -81,7 +81,7 @@ const gdemo_default = gdemo_d() return s, m end - gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m)) + gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8)) chain = sample(fggibbstest(xs), gibbs, 2) end @testset "new grammar" begin @@ -91,7 +91,7 @@ const gdemo_default = gdemo_d() priors = Array{Float64}(undef, 2) priors[1] ~ InverseGamma(2, 3) # s priors[2] ~ Normal(0, sqrt(priors[1])) # m - for i in 1:length(x) + for i in eachindex(x) x[i] ~ Normal(priors[2], sqrt(priors[1])) end return priors @@ -105,7 +105,7 @@ const gdemo_default = gdemo_d() priors = TV(undef, 2) priors[1] ~ InverseGamma(2, 3) # s priors[2] ~ Normal(0, sqrt(priors[1])) # m - for i in 1:length(x) + for i in eachindex(x) x[i] ~ Normal(priors[2], sqrt(priors[1])) end return priors @@ -126,7 +126,7 @@ const gdemo_default = gdemo_d() @model function newinterface(obs) p ~ Beta(2, 2) - for i in 1:length(obs) + for i in eachindex(obs) obs[i] ~ Bernoulli(p) end return p @@ -142,7 +142,7 @@ const gdemo_default = gdemo_d() @model function noreturn(x) s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) - for i in 1:length(x) + for i in eachindex(x) x[i] ~ Normal(m, sqrt(s)) end end @@ -177,7 +177,7 @@ const gdemo_default = gdemo_d() end @testset "sample" begin - alg = Gibbs(HMC(0.2, 3, :m), PG(10, :s)) + alg = Gibbs(; m=HMC(0.2, 3), s=PG(10)) chn = sample(gdemo_default, alg, 1000) end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 87779d0f0b..9356fbcc10 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -31,8 +31,8 @@ using Turing PG(10), IS(), MH(), - Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + Gibbs(; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)), + Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), ) for sampler in samplers Random.seed!(5) @@ -81,7 +81,7 @@ using Turing @testset "chain save/resume" begin alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) alg2 = PG(20) - alg3 = Gibbs(PG(30, :s), HMC(0.2, 4, :m; adtype=adbackend)) + alg3 = Gibbs(; s=PG(30), m=HMC(0.2, 4; adtype=adbackend)) chn1 = sample(StableRNG(seed), gdemo_default, alg1, 2_000; 
save_state=true) check_gdemo(chn1) @@ -260,7 +260,7 @@ using Turing smc = SMC() pg = PG(10) - gibbs = Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x)) + gibbs = Gibbs(; p=HMC(0.2, 3; adtype=adbackend), x=PG(10)) chn_s = sample(StableRNG(seed), testbb(obs), smc, 200) chn_p = sample(StableRNG(seed), testbb(obs), pg, 200) @@ -288,7 +288,7 @@ using Turing return s, m end - gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m; adtype=adbackend)) + gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8; adtype=adbackend)) chain = sample(StableRNG(seed), fggibbstest(xs), gibbs, 2) end @@ -415,7 +415,7 @@ using Turing end @testset "sample" begin - alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) + alg = Gibbs(; m=HMC(0.2, 3; adtype=adbackend), s=PG(10)) chn = sample(StableRNG(seed), gdemo_default, alg, 10) end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 23f4a11ae8..92fcaf7d95 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -40,7 +40,7 @@ using Turing c3 = sample(demodot_default, s1, N) c4 = sample(demodot_default, s2, N) - s3 = Gibbs(ESS(:m), MH(:s)) + s3 = Gibbs(; m=ESS(), s=MH()) c5 = sample(gdemo_default, s3, N) end @@ -59,13 +59,17 @@ using Turing end @testset "gdemo with CSMC + ESS" begin - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(; s=CSMC(15), m=ESS()) chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 2000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end @testset "MoGtest_default with CSMC + ESS" begin - alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) + alg = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) chain = sample(StableRNG(seed), MoGtest_default, alg, 2000) check_MoGtest_default(chain; atol=0.1) end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index cd044910b8..503bf16cb4 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -1,88 +1,411 @@ module GibbsTests -using ..Models: MoGtest_default, gdemo, gdemo_default -using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical +using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo, gdemo_default +using ..NumericalTests: + check_MoGtest_default, + check_MoGtest_default_z_vector, + check_gdemo, + check_numerical, + two_sample_test import ..ADUtils +import Combinatorics using Distributions: InverseGamma, Normal using Distributions: sample +using DynamicPPL: DynamicPPL using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff import Mooncake -using Test: @test, @testset +using StableRNGs: StableRNG +using Test: @inferred, @test, @test_broken, @test_deprecated, @test_throws, @testset using Turing using Turing: Inference +using Turing.Inference: AdvancedHMC, AdvancedMH using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess +function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) + transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val + [first(vn_and_val)] + end + # Varnames in `transition` should be subsumed by those in `parent_varnames`. 
+ for vn in transition_varnames + @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) + end +end + +const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe_literal)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, +} +has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false +has_dot_assume(::DynamicPPL.Model) = true + +@testset "GibbsContext" begin + @testset "type stability" begin + # A test model that has multiple features in one package: + # Floats, Ints, arguments, observations, loops, dot_tildes. + @model function test_model(obs1, obs2, num_vars, mean) + variance ~ Exponential(2) + z = Vector{Float64}(undef, num_vars) + z .~ truncated(Normal(mean, variance); lower=1) + y = Vector{Int64}(undef, num_vars) + for i in 1:num_vars + y[i] ~ Poisson(Int(round(z[i]))) + end + s = sum(y) - sum(z) + obs1 ~ Normal(s, 1) + obs2 ~ Poisson(y[3]) + return obs1, obs2, variance, z, y, s + end + + model = test_model(1.2, 2, 10, 2.5) + all_varnames = DynamicPPL.VarName[@varname(variance), @varname(z), @varname(y)] + # All combinations of elements in all_varnames. + target_vn_combinations = Iterators.flatten( + Iterators.map( + n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames) + ), + ) + + @testset "$(target_vns)" for target_vns in target_vn_combinations + global_varinfo = DynamicPPL.VarInfo(model) + target_vns = collect(target_vns) + local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) + ctx = Turing.Inference.GibbsContext( + target_vns, Ref(global_varinfo), Turing.DefaultContext() + ) + + # Check that the correct varnames are conditioned, and that getting their + # values is type stable when the varinfo is. + for k in keys(global_varinfo) + is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)) + @test Turing.Inference.is_target_varname(ctx, k) == is_target + if !is_target + @inferred Turing.Inference.get_conditioned_gibbs(ctx, k) + end + end + + # Check the type stability also in the dot_tilde pipeline. + for k in all_varnames + # The map(identity, ...) part is there to concretise the eltype. + subkeys = map( + identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)) + ) + is_target = (k in target_vns) + @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target + if !is_target + @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys) + end + end + + # Check that evaluate!! and the result it returns are type stable. 
+ conditioned_model = DynamicPPL.contextualize(model, ctx) + _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!( + conditioned_model, local_varinfo + ) + for k in keys(post_eval_varinfo) + @inferred post_eval_varinfo[k] + end + end + end +end + +@testset "Invalid Gibbs constructor" begin + # More samplers than varnames or vice versa + @test_throws ArgumentError Gibbs((@varname(s), @varname(m)), (NUTS(), NUTS(), NUTS())) + @test_throws ArgumentError Gibbs( + (@varname(s), @varname(m), @varname(x)), (NUTS(), NUTS()) + ) + # Invalid samplers + @test_throws ArgumentError Gibbs(@varname(s) => IS()) + @test_throws ArgumentError Gibbs(@varname(s) => Emcee(10, 2.0)) + @test_throws ArgumentError Gibbs( + @varname(s) => SGHMC(; learning_rate=0.01, momentum_decay=0.1) + ) + @test_throws ArgumentError Gibbs( + @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) + ) +end + +# Test that the samplers are being called in the correct order, on the correct target +# variables. +@testset "Sampler call order" begin + # A wrapper around inference algorithms to allow intercepting the dispatch cascade to + # collect testing information. + struct AlgWrapper{Alg<:Inference.InferenceAlgorithm} <: Inference.InferenceAlgorithm + inner::Alg + end + + unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) = + DynamicPPL.Sampler(sampler.alg.inner, sampler.selector) + + # Methods we need to define to be able to use AlgWrapper instead of an actual algorithm. + # They all just propagate the call to the inner algorithm. + Inference.isgibbscomponent(wrap::AlgWrapper) = Inference.isgibbscomponent(wrap.inner) + Inference.drop_space(wrap::AlgWrapper) = AlgWrapper(Inference.drop_space(wrap.inner)) + function Inference.setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgWrapper}, + state, + params::Turing.AbstractVarInfo, + ) + return Inference.setparams_varinfo!!(model, unwrap_sampler(sampler), state, params) + end + + function target_vns(::Inference.GibbsContext{VNs}) where {VNs} + return VNs + end + + # targets_and_algs will be a list of tuples, where the first element is the target_vns + # of a component sampler, and the second element is the component sampler itself. + # It is modified by the capture_targets_and_algs function. + targets_and_algs = Any[] + + function capture_targets_and_algs(sampler, context) + if DynamicPPL.NodeTrait(context) == DynamicPPL.IsLeaf() + return nothing + end + if context isa Inference.GibbsContext + push!(targets_and_algs, (target_vns(context), sampler)) + end + return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context)) + end + + # The methods that capture testing information for us. + function Turing.AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgWrapper}, + args...; + kwargs..., + ) + capture_targets_and_algs(sampler.alg.inner, model.context) + return Turing.AbstractMCMC.step( + rng, model, unwrap_sampler(sampler), args...; kwargs... + ) + end + + function Turing.DynamicPPL.initialstep( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgWrapper}, + args...; + kwargs..., + ) + capture_targets_and_algs(sampler.alg.inner, model.context) + return Turing.DynamicPPL.initialstep( + rng, model, unwrap_sampler(sampler), args...; kwargs... + ) + end + + # A test model that includes several different kinds of tilde syntax. 
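+ # Specifically: plain scalar assumes (`s`, `m`), a literal observation (`1.0 ~ ...`),
+ # a deterministic `:=` statement, indexed assumes inside a loop (`xs[i]`), and a
+ # broadcasted dot-tilde (`ys .~ ...`).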
+ @model function test_model(val, ::Type{M}=Vector{Float64}) where {M}
+ s ~ Normal(0.1, 0.2)
+ m ~ Poisson()
+ val ~ Normal(s, 1)
+ 1.0 ~ Normal(s + m, 1)
+
+ n := m + 1
+ xs = M(undef, n)
+ for i in eachindex(xs)
+ xs[i] ~ Beta(0.5, 0.5)
+ end
+
+ ys = M(undef, 2)
+ ys .~ Beta(1.0, 1.0)
+ return sum(xs), sum(ys), n
+ end
+
+ mh = MH()
+ pg = PG(10)
+ hmc = HMC(0.01, 4)
+ nuts = NUTS()
+ # Sample with all sorts of combinations of samplers and targets.
+ sampler = Gibbs(
+ (@varname(s),) => AlgWrapper(mh),
+ (@varname(s), @varname(m)) => AlgWrapper(mh),
+ (@varname(m),) => AlgWrapper(pg),
+ (@varname(xs),) => AlgWrapper(hmc),
+ (@varname(ys),) => AlgWrapper(nuts),
+ (@varname(ys),) => AlgWrapper(nuts),
+ (@varname(xs), @varname(ys)) => AlgWrapper(hmc),
+ (@varname(s),) => AlgWrapper(mh),
+ )
+ chain = sample(test_model(-1), sampler, 2)
+
+ expected_targets_and_algs_per_iteration = [
+ ((:s,), mh),
+ ((:s, :m), mh),
+ ((:m,), pg),
+ ((:xs,), hmc),
+ ((:ys,), nuts),
+ ((:ys,), nuts),
+ ((:xs, :ys), hmc),
+ ((:s,), mh),
+ ]
+ @test targets_and_algs == vcat(
+ expected_targets_and_algs_per_iteration, expected_targets_and_algs_per_iteration
+ )
+end
+
+@testset "Equivalence of RepeatSampler and repeating Sampler" begin
+ sampler1 = Gibbs(@varname(s) => RepeatSampler(MH(), 3), @varname(m) => ESS())
+ sampler2 = Gibbs(
+ @varname(s) => MH(), @varname(s) => MH(), @varname(s) => MH(), @varname(m) => ESS()
+ )
+ Random.seed!(23)
+ chain1 = sample(gdemo_default, sampler1, 10)
+ Random.seed!(23)
+ chain2 = sample(gdemo_default, sampler2, 10)
+ @test chain1.value == chain2.value
+end
+
 @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
- @testset "gibbs constructor" begin
- N = 500
- s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend))
- s2 = Gibbs(PG(10, :s, :m))
- s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend))
- s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend))
- s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend))
- s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m))
- for s in (s1, s2, s3, s4, s5, s6)
+ @info "Starting Gibbs tests with $adbackend"
+ @testset "Deprecated Gibbs constructors" begin
+ N = 10
+ @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend))
+ @test_deprecated s2 = Gibbs(PG(10, :s, :m))
+ @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend))
+ @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend))
+ @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend))
+ @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m))
+ @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3))
+ for s in (s1, s2, s3, s4, s5, s6, s7)
 @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs"
 end
- c1 = sample(gdemo_default, s1, N)
- c2 = sample(gdemo_default, s2, N)
- c3 = sample(gdemo_default, s3, N)
- c4 = sample(gdemo_default, s4, N)
- c5 = sample(gdemo_default, s5, N)
- c6 = sample(gdemo_default, s6, N)
+ # Check that the samplers work despite using the deprecated constructor.
+ sample(gdemo_default, s1, N)
+ sample(gdemo_default, s2, N)
+ sample(gdemo_default, s3, N)
+ sample(gdemo_default, s4, N)
+ sample(gdemo_default, s5, N)
+ sample(gdemo_default, s6, N)
+ sample(gdemo_default, s7, N)
- # Test gid of each samplers
 g = Turing.Sampler(s3, gdemo_default)
+ @test sample(gdemo_default, g, N) isa MCMCChains.Chains
+ end
- _, state = AbstractMCMC.step(Random.default_rng(), gdemo_default, g)
- @test state.samplers[1].selector != g.selector
- @test state.samplers[2].selector != g.selector
- @test state.samplers[1].selector != state.samplers[2].selector
+ @testset "Gibbs constructors" begin
+ # Create Gibbs samplers with various configurations and ways of passing the
+ # arguments, and run them all on the `gdemo_default` model, see that nothing breaks.
+ N = 10
+ # Two variables being sampled by one sampler.
+ s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend))
+ s2 = Gibbs((@varname(s), :m) => PG(10))
+ # One variable per sampler, using the keyword arg interface.
+ s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)))
+ # As above but using a Dict of VarNames.
+ s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend)))
+ # As above but different samplers and using kwargs.
+ s5 = Gibbs(; s=CSMC(3), m=HMCDA(200, 0.65, 0.15; adtype=adbackend))
+ s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS())
+ s7 = Gibbs(Dict((:s, @varname(m)) => PG(10)))
+ # Multiple instances of the same sampler. This implements running, in this case,
+ # 3 steps of HMC on s and 2 steps of PG on m in every iteration of Gibbs.
+ s8 = begin
+ hmc = HMC(0.1, 5; adtype=adbackend)
+ pg = PG(10)
+ vns = @varname(s)
+ vnm = @varname(m)
+ Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg)
+ end
+ # Same thing but using RepeatSampler.
+ s9 = Gibbs(
+ @varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3),
+ @varname(m) => RepeatSampler(PG(10), 2),
+ )
+ for s in (s1, s2, s3, s4, s5, s6, s7, s8, s9)
+ @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs"
+ end
- # run sampler: progress logging should be disabled and
- # it should return a Chains object
+ sample(gdemo_default, s1, N)
+ sample(gdemo_default, s2, N)
+ sample(gdemo_default, s3, N)
+ sample(gdemo_default, s4, N)
+ sample(gdemo_default, s5, N)
+ sample(gdemo_default, s6, N)
+ sample(gdemo_default, s7, N)
+ sample(gdemo_default, s8, N)
+ sample(gdemo_default, s9, N)
+
+ g = Turing.Sampler(s3, gdemo_default)
 @test sample(gdemo_default, g, N) isa MCMCChains.Chains
 end
- @testset "gibbs inference" begin
- Random.seed!(100)
- alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m; adtype=adbackend))
- chain = sample(gdemo(1.5, 2.0), alg, 10_000)
- check_numerical(chain, [:m], [7 / 6]; atol=0.15)
- # Be more relaxed with the tolerance of the variance.
- check_numerical(chain, [:s], [49 / 24]; atol=0.35)
- Random.seed!(100)
+ # Test various combinations of samplers against models for which we know the analytical
+ # posterior mean.
+ @testset "Gibbs inference" begin
+ @testset "CSMC and HMC on gdemo" begin
+ alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend))
+ chain = sample(gdemo(1.5, 2.0), alg, 3_000)
+ check_numerical(chain, [:m], [7 / 6]; atol=0.15)
+ # Be more relaxed with the tolerance of the variance.
+ check_numerical(chain, [:s], [49 / 24]; atol=0.35) + end - alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend)) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + @testset "MH and HMCDA on gdemo" begin + alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend)) + chain = sample(gdemo(1.5, 2.0), alg, 3_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + end - alg = Gibbs(CSMC(15, :s), ESS(:m)) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + @testset "CSMC and ESS on gdemo" begin + alg = Gibbs(; s=CSMC(15), m=ESS()) + chain = sample(gdemo(1.5, 2.0), alg, 3_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + end - alg = CSMC(15) - chain = sample(gdemo(1.5, 2.0), alg, 10_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + # TODO(mhauru) Why is this in the Gibbs test suite? + @testset "CSMC on gdemo" begin + alg = CSMC(15) + chain = sample(gdemo(1.5, 2.0), alg, 4_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + end - Random.seed!(200) - gibbs = Gibbs( - PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2; adtype=adbackend) - ) - chain = sample(MoGtest_default, gibbs, 10_000) - check_MoGtest_default(chain; atol=0.15) + @testset "PG and HMC on MoGtest_default" begin + gibbs = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + ) + chain = sample(MoGtest_default, gibbs, 2_000) + check_MoGtest_default(chain; atol=0.15) + end - Random.seed!(200) - for alg in [ - Gibbs((MH(:s), 2), (HMC(0.2, 4, :m; adtype=adbackend), 1)), - Gibbs((MH(:s), 1), (HMC(0.2, 4, :m; adtype=adbackend), 2)), - ] - chain = sample(gdemo(1.5, 2.0), alg, 10_000) + @testset "Multiple overlapping samplers on gdemo" begin + # Test samplers that are run multiple times, or have overlapping targets. 
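+ # Correctness is only checked through the posterior summaries in `check_gdemo` below;
+ # the point is that repeated and overlapping component samplers do not corrupt the result.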
+ alg = Gibbs( + @varname(s) => MH(), + (@varname(s), @varname(m)) => MH(), + @varname(m) => ESS(), + @varname(s) => RepeatSampler(MH(), 3), + @varname(m) => HMC(0.2, 4; adtype=adbackend), + (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend), + ) + chain = sample(gdemo(1.5, 2.0), alg, 500) check_gdemo(chain; atol=0.15) end + + @testset "Multiple overlapping samplers on MoGtest_default" begin + gibbs = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(z1), @varname(z2)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + (@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2), + (@varname(mu1)) => ESS(), + (@varname(mu2)) => ESS(), + (@varname(z1), @varname(z2)) => PG(15), + ) + chain = sample(MoGtest_default, gibbs, 500) + check_MoGtest_default(chain; atol=0.15) + end end @testset "transitions" begin @@ -112,9 +435,10 @@ using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess return nothing end - alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) sample(model, alg, 100; callback=callback) end + @testset "dynamic model" begin @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} N = length(y) @@ -135,10 +459,304 @@ using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess m[k] ~ Normal(1.0, 1.0) end end - model = imm(randn(100), 1.0) + num_zs = 100 + num_samples = 10_000 + model = imm(Random.randn(num_zs), 1.0) # https://github.com/TuringLang/Turing.jl/issues/1725 - # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100); - sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m; adtype=adbackend)), 100) + # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); + chn = sample( + StableRNG(23), + model, + Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), + num_samples, + ) + # The number of m variables that have a non-zero value in a sample. + num_ms = count(ismissing.(Array(chn[:, (num_zs + 1):end, 1])); dims=2) + # The below are regression tests. The values we are comparing against are from + # running the above model on the "old" Gibbs sampler that was in place still on + # 2024-11-20. The model was run 5 times with 10_000 samples each time. The values + # to compare to are the mean of those 5 runs, atol is roughly estimated from the + # standard deviation of those 5 runs. + # TODO(mhauru) Could we do something smarter here? Maybe a dynamic model for which + # the posterior is analytically known? Doing 10_000 samples to run the test suite + # is not ideal + # Issue ref: https://github.com/TuringLang/Turing.jl/issues/2402 + @test isapprox(mean(num_ms), 8.6087; atol=0.8) + @test isapprox(std(num_ms), 1.8865; atol=0.02) + end + + # The below test used to sample incorrectly before + # https://github.com/TuringLang/Turing.jl/pull/2328 + @testset "dynamic model with ESS" begin + @model function dynamic_model_for_ess() + b ~ Bernoulli() + x_length = b ? 
1 : 2 + x = Vector{Float64}(undef, x_length) + for i in 1:x_length + x[i] ~ Normal(i, 1.0) + end + end + + m = dynamic_model_for_ess() + chain = sample(m, Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100) + means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0) + stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0) + for vn in keys(means) + @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1) + @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1) + end + end + + @testset "dynamic model with dot tilde" begin + @model function dynamic_model_with_dot_tilde( + num_zs=10, ::Type{M}=Vector{Float64} + ) where {M} + z = M(undef, num_zs) + z .~ Poisson(1.0) + num_ms = sum(z) + m = M(undef, num_ms) + return m .~ Normal(1.0, 1.0) + end + model = dynamic_model_with_dot_tilde() + # TODO(mhauru) This is broken because of + # https://github.com/TuringLang/DynamicPPL.jl/issues/700. + @test_broken ( + sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); + true + ) + end + + @testset "Demo models" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + samplers = [ + Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()), + Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => HMC(0.01, 4)), + Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => ESS()), + ] + + if !has_dot_assume(model) + # Add in some MH samplers, which are not compatible with `.~`. + append!( + samplers, + [ + Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()), + Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)), + ], + ) + end + + @testset "$sampler" for sampler in samplers + # Check that taking steps performs as expected. + rng = Random.default_rng() + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(sampler) + ) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(sampler), state + ) + check_transition_varnames(transition, vns) + end + end + + # Run the Gibbs sampler and NUTS on the same model, compare statistics of the + # chains. + @testset "comparison with 'gold-standard' samples" begin + num_iterations = 1_000 + thinning = 10 + num_chains = 4 + + # Determine initial parameters to make comparison as fair as possible. + posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) + initial_params = DynamicPPL.TestUtils.update_values!!( + DynamicPPL.VarInfo(model), + posterior_mean, + DynamicPPL.TestUtils.varnames(model), + )[:] + initial_params = fill(initial_params, num_chains) + + # Sampler to use for Gibbs components. + hmc = HMC(0.1, 32) + sampler = Turing.Gibbs(@varname(s) => hmc, @varname(m) => hmc) + Random.seed!(42) + chain = sample( + model, + sampler, + MCMCThreads(), + num_iterations, + num_chains; + progress=false, + initial_params=initial_params, + discard_initial=1_000, + thinning=thinning, + ) + + # "Ground truth" samples. + # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. + Random.seed!(42) + chain_true = sample( + model, + NUTS(), + MCMCThreads(), + num_iterations, + num_chains; + progress=false, + initial_params=initial_params, + thinning=thinning, + ) + + # Perform KS test to ensure that the chains are similar. 
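+ # `two_sample_test` compares the marginal samples of the Gibbs chain against the NUTS
+ # reference chain, one parameter (column) at a time.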
+ xs = Array(chain) + xs_true = Array(chain_true) + for i in 1:size(xs, 2) + @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) + # Let's make sure that the significance level is not too low by + # checking that the KS test fails for some simple transformations. + # TODO: Replace the heuristic below with closed-form implementations + # of the targets, once they are implemented in DynamicPPL. + @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) + end + end + end + end + + @testset "multiple varnames" begin + rng = Random.default_rng() + + @testset "with both `s` and `m` as random" begin + model = gdemo(1.5, 2.0) + vns = (@varname(s), @varname(m)) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # `sample` + Random.seed!(42) + chain = sample(model, alg, 1_000; progress=false) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) + end + + @testset "without `m` as random" begin + model = gdemo(1.5, 2.0) | (m=7 / 6,) + vns = (@varname(s),) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + end + end + + @testset "CSMC + ESS" begin + rng = Random.default_rng() + model = MoGtest_default + alg = Turing.Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) + vns = ( + @varname(z1), + @varname(z2), + @varname(z3), + @varname(z4), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! + Random.seed!(42) + chain = sample(MoGtest_default, alg, 1000; progress=false) + check_MoGtest_default(chain; atol=0.2) + end + + @testset "CSMC + ESS (usage of implicit varname)" begin + rng = Random.default_rng() + model = MoGtest_default_z_vector + alg = Turing.Gibbs( + @varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS() + ) + vns = ( + @varname(z[1]), + @varname(z[2]), + @varname(z[3]), + @varname(z[4]), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! 
+ Random.seed!(42)
+ chain = sample(model, alg, 1000; progress=false)
+ check_MoGtest_default_z_vector(chain; atol=0.2)
+ end
+
+ @testset "externalsampler" begin
+ @model function demo_gibbs_external()
+ m1 ~ Normal()
+ m2 ~ Normal()
+
+ -1 ~ Normal(m1, 1)
+ +1 ~ Normal(m1 + m2, 1)
+
+ return (; m1, m2)
+ end
+
+ model = demo_gibbs_external()
+ samplers_inner = [
+ externalsampler(AdvancedMH.RWMH(1)),
+ externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoForwardDiff()),
+ externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff()),
+ externalsampler(
+ AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff(; compile=true)
+ ),
+ ]
+ @testset "$(sampler_inner)" for sampler_inner in samplers_inner
+ sampler = Turing.Gibbs(
+ @varname(m1) => sampler_inner, @varname(m2) => sampler_inner
+ )
+ Random.seed!(42)
+ chain = sample(
+ model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
+ )
+ check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
+ end
 end
end
diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl
deleted file mode 100644
index d6d81cbe09..0000000000
--- a/test/mcmc/gibbs_conditional.jl
+++ /dev/null
@@ -1,171 +0,0 @@
-module GibbsConditionalTests
-
-using ..Models: gdemo, gdemo_default
-using ..NumericalTests: check_gdemo, check_numerical
-import ..ADUtils
-using Clustering: Clustering
-using Distributions: Categorical, InverseGamma, Normal, sample
-using ForwardDiff: ForwardDiff
-using LinearAlgebra: Diagonal, I
-using Random: Random
-using ReverseDiff: ReverseDiff
-using StableRNGs: StableRNG
-using StatsBase: counts
-using StatsFuns: StatsFuns
-import Mooncake
-using Test: @test, @testset
-using Turing
-
-@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ADUtils.adbackends
- Random.seed!(1000)
- rng = StableRNG(123)
-
- @testset "gdemo" begin
- # We consider the model
- # ```math
- # s ~ InverseGamma(2, 3)
- # m ~ Normal(0, √s)
- # xᵢ ~ Normal(m, √s), i = 1, …, N,
- # ```
- # with ``N = 2`` observations ``x₁ = 1.5`` and ``x₂ = 2``.
- - # The conditionals and posterior can be formulated in terms of the following statistics: - N = 2 - x_mean = 1.75 # sample mean ``∑ xᵢ / N`` - x_var = 0.0625 # sample variance ``∑ (xᵢ - x_bar)^2 / N`` - m_n = 3.5 / 3 # ``∑ xᵢ / (N + 1)`` - - # Conditional distribution - # ```math - # m | s, x ~ Normal(m_n, sqrt(s / (N + 1))) - # ``` - cond_m = let N = N, m_n = m_n - c -> Normal(m_n, sqrt(c.s / (N + 1))) - end - - # Conditional distribution - # ```math - # s | m, x ~ InverseGamma(2 + (N + 1) / 2, 3 + (m^2 + ∑ (xᵢ - m)^2) / 2) = - # InverseGamma(2 + (N + 1) / 2, 3 + m^2 / 2 + N / 2 * (x_var + (x_mean - m)^2)) - # ``` - cond_s = let N = N, x_mean = x_mean, x_var = x_var - c -> InverseGamma( - 2 + (N + 1) / 2, 3 + c.m^2 / 2 + N / 2 * (x_var + (x_mean - c.m)^2) - ) - end - - # Three Gibbs samplers: - # one for each variable fixed to the posterior mean - s_posterior_mean = 49 / 24 - sampler1 = Gibbs( - GibbsConditional(:m, cond_m), - GibbsConditional(:s, _ -> Normal(s_posterior_mean, 0)), - ) - chain = sample(rng, gdemo_default, sampler1, 10_000) - cond_m_mean = mean(cond_m((s=s_posterior_mean,))) - check_numerical(chain, [:m, :s], [cond_m_mean, s_posterior_mean]) - @test all(==(s_posterior_mean), chain[:s][2:end]) - - m_posterior_mean = 7 / 6 - sampler2 = Gibbs( - GibbsConditional(:m, _ -> Normal(m_posterior_mean, 0)), - GibbsConditional(:s, cond_s), - ) - chain = sample(rng, gdemo_default, sampler2, 10_000) - cond_s_mean = mean(cond_s((m=m_posterior_mean,))) - check_numerical(chain, [:m, :s], [m_posterior_mean, cond_s_mean]) - @test all(==(m_posterior_mean), chain[:m][2:end]) - - # and one for both using the conditional - sampler3 = Gibbs(GibbsConditional(:m, cond_m), GibbsConditional(:s, cond_s)) - chain = sample(rng, gdemo_default, sampler3, 10_000) - check_gdemo(chain) - end - - @testset "GMM" begin - Random.seed!(1000) - rng = StableRNG(123) - # We consider the model - # ```math - # μₖ ~ Normal(m, σ_μ), k = 1, …, K, - # zᵢ ~ Categorical(π), i = 1, …, N, - # xᵢ ~ Normal(μ_{zᵢ}, σₓ), i = 1, …, N, - # ``` - # with ``K = 2`` clusters, ``N = 20`` observations, and the following parameters: - K = 2 # number of clusters - π = fill(1 / K, K) # uniform cluster weights - m = 0.5 # prior mean of μₖ - σ²_μ = 4.0 # prior variance of μₖ - σ²_x = 0.01 # observation variance - N = 20 # number of observations - - # We generate data - μ_data = rand(rng, Normal(m, sqrt(σ²_μ)), K) - z_data = rand(rng, Categorical(π), N) - x_data = rand(rng, MvNormal(μ_data[z_data], σ²_x * I)) - - @model function mixture(x) - μ ~ $(MvNormal(fill(m, K), σ²_μ * I)) - z ~ $(filldist(Categorical(π), N)) - x ~ MvNormal(μ[z], $(σ²_x * I)) - return x - end - model = mixture(x_data) - - # Conditional distribution ``z | μ, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_z = let x = x_data, log_π = log.(π), σ_x = sqrt(σ²_x) - c -> begin - dists = map(x) do xi - logp = log_π .+ logpdf.(Normal.(c.μ, σ_x), xi) - return Categorical(StatsFuns.softmax!(logp)) - end - return arraydist(dists) - end - end - - # Conditional distribution ``μ | z, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_μ = let K = K, x_data = x_data, inv_σ²_μ = inv(σ²_μ), inv_σ²_x = inv(σ²_x) - c -> begin - # Convert cluster assignments to one-hot encodings - z_onehot = c.z .== (1:K)' - - # Count number of observations in each cluster - n = vec(sum(z_onehot; dims=1)) - - # Compute mean and variance of the conditional distribution - μ_var = @. 
inv(inv_σ²_x * n + inv_σ²_μ) - μ_mean = (z_onehot' * x_data) .* inv_σ²_x .* μ_var - - return MvNormal(μ_mean, Diagonal(μ_var)) - end - end - - estimate(chain, var) = dropdims(mean(Array(group(chain, var)); dims=1); dims=1) - function estimatez(chain, var, range) - z = Int.(Array(group(chain, var))) - return map(i -> findmax(counts(z[:, i], range))[2], 1:size(z, 2)) - end - - lμ_data, uμ_data = extrema(μ_data) - - # Compare three Gibbs samplers - sampler1 = Gibbs(GibbsConditional(:z, cond_z), GibbsConditional(:μ, cond_μ)) - sampler2 = Gibbs(GibbsConditional(:z, cond_z), MH(:μ)) - sampler3 = Gibbs(GibbsConditional(:z, cond_z), HMC(0.01, 7, :μ; adtype=adbackend)) - for sampler in (sampler1, sampler2, sampler3) - chain = sample(rng, model, sampler, 10_000) - - μ_hat = estimate(chain, :μ) - lμ_hat, uμ_hat = extrema(μ_hat) - @test isapprox([lμ_data, uμ_data], [lμ_hat, uμ_hat], atol=0.1) - - z_hat = estimatez(chain, :z, 1:2) - ari, _, _, _ = Clustering.randindex(z_data, Int.(z_hat)) - @test isapprox(ari, 1, atol=0.1) - end - end -end - -end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 6b07c73780..47ff73b1c5 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -146,7 +146,7 @@ using Turing # explicitly specifying the seeds here. @testset "hmcda+gibbs inference" begin Random.seed!(12345) - alg = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) + alg = Gibbs(; s=PG(20), m=HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend)) res = sample(StableRNG(123), gdemo_default, alg, 3000; discard_initial=1000) check_gdemo(res) end @@ -199,9 +199,9 @@ using Turing end @testset "AHMC resize" begin - alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s; adtype=adbackend)) - alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s; adtype=adbackend)) - alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s; adtype=adbackend)) + alg1 = Gibbs(; m=PG(10), s=NUTS(100, 0.65; adtype=adbackend)) + alg2 = Gibbs(; m=PG(10), s=HMC(0.1, 3; adtype=adbackend)) + alg3 = Gibbs(; m=PG(10), s=HMCDA(100, 0.65, 0.3; adtype=adbackend)) @test sample(StableRNG(seed), gdemo_default, alg1, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg2, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg3, 10) isa Chains diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index d71a5fbc67..3823c2986c 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -34,7 +34,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) c2 = sample(gdemo_default, s2, N) c3 = sample(gdemo_default, s3, N) - s4 = Gibbs(MH(:m), MH(:s)) + s4 = Gibbs(; m=MH(), s=MH()) c4 = sample(gdemo_default, s4, N) # s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) @@ -69,7 +69,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) end @testset "gdemo_default with MH-within-Gibbs" begin - alg = Gibbs(MH(:m), MH(:s)) + alg = Gibbs(; m=MH(), s=MH()) chain = sample( StableRNG(seed), gdemo_default, alg, 10_000; discard_initial, initial_params ) @@ -78,7 +78,9 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @testset "MoGtest_default with Gibbs" begin gibbs = Gibbs( - CSMC(15, :z1, :z2, :z3, :z4), MH((:mu1, GKernel(1)), (:mu2, GKernel(1))) + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => MH((:mu1, GKernel(1))), + @varname(mu2) => MH((:mu2, GKernel(1))), ) chain = sample( StableRNG(seed), @@ -175,9 +177,8 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # with small-valued VC matrix to check if we only see very small steps vc_μ = convert(Array, 1e-4 * I(2)) vc_σ = convert(Array, 1e-4 * I(2)) - alg_small 
= Gibbs(MH((:μ, vc_μ)), MH((:σ, vc_σ))) + alg_small = Gibbs(; μ=MH((:μ, vc_μ)), σ=MH((:σ, vc_σ))) alg_big = MH() - chn_small = sample(StableRNG(seed), mod, alg_small, 1_000) chn_big = sample(StableRNG(seed), mod, alg_big, 1_000) diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl new file mode 100644 index 0000000000..7328d1168c --- /dev/null +++ b/test/mcmc/repeat_sampler.jl @@ -0,0 +1,35 @@ +module RepeatSamplerTests + +using ..Models: gdemo_default +using DynamicPPL: Sampler +using StableRNGs: StableRNG +using Test: @test, @testset +using Turing + +# RepeatSampler only really makes sense as a component sampler of Gibbs. +# Here we just check that running it by itself is equivalent to thinning. +@testset "RepeatSampler" begin + num_repeats = 17 + num_samples = 10 + num_chains = 2 + + rng = StableRNG(0) + for sampler in [MH(), Sampler(HMC(0.01, 4))] + chn1 = sample( + copy(rng), + gdemo_default, + sampler, + MCMCThreads(), + num_samples, + num_chains; + thinning=num_repeats, + ) + repeat_sampler = RepeatSampler(sampler, num_repeats) + chn2 = sample( + copy(rng), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, num_chains + ) + @test chn1.value == chn2.value + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 12c6ccab87..093615e540 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,13 +53,13 @@ end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" verbose = true begin @timeit_include("mcmc/gibbs.jl") - @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") @timeit_include("mcmc/abstractmcmc.jl") @timeit_include("mcmc/mh.jl") @timeit_include("ext/dynamichmc.jl") + @timeit_include("mcmc/repeat_sampler.jl") end @testset "variational algorithms" begin @@ -72,10 +72,6 @@ end end end - @testset "experimental" begin - @timeit_include("experimental/gibbs.jl") - end - @testset "variational optimisers" begin @timeit_include("variational/optimisers.jl") end diff --git a/test/skipped/explicit_ret.jl b/test/skipped/explicit_ret.jl index c1340464fb..2dabc09bd3 100644 --- a/test/skipped/explicit_ret.jl +++ b/test/skipped/explicit_ret.jl @@ -12,7 +12,7 @@ end mf = test_ex_rt() for alg in - [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(PG(20, 1, :x), HMC(0.2, 3, :y))] + [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(; x=PG(20, 1), y=HMC(0.2, 3))] chn = sample(mf, alg) @test mean(chn[:x]) ≈ 10.0 atol = 0.2 @test mean(chn[:y]) ≈ 5.0 atol = 0.2