diff --git a/docs/src/gibbs.md b/docs/src/gibbs.md
index bb2b345..9de77c0 100644
--- a/docs/src/gibbs.md
+++ b/docs/src/gibbs.md
@@ -8,7 +8,7 @@ LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, st
 This function takes the logdensity model and the state, and returns the log probability of the state.
 If `recompute_logp` is `true`, it should recompute the log probability of the state.
-Otherwise, it should use the log probability stored in the state.
+Otherwise, it could use the log probability stored in the state.
 
 ```julia
 Base.vec(state)
@@ -20,9 +20,11 @@ This function takes the state and returns a vector of the parameter values stored in the state.
 ```julia
 (state::StateType)(logp::Float64)
 ```
 
-This function takes the state and a log probability value, and updates the state with the new log probability.
+This function takes the state and a log probability value, and returns a new state with the updated log probability.
 
-These function will provide a minimum interface to interact with the `state` datatype, which a sampler package doesn't have to expose.
+These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement.
+The interface facilitates the implementation of "meta-algorithms" that combine different samplers.
+We will demonstrate how it can be used to implement Gibbs sampling in the following sections.
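+
+To make the contract concrete, the following sketch shows a state type that satisfies all three functions (the name `MyState` and the keyword placement of `recompute_logp` are illustrative assumptions, not part of `AbstractMCMC`):
+
+```julia
+struct MyState
+    "Parameter values, flattened into a vector."
+    params::Vector{Float64}
+    "Cached log probability of `params`."
+    logp::Float64
+end
+
+# Return the cached log probability, or recompute it from the model.
+function LogDensityProblems.logdensity(
+    model::AbstractMCMC.LogDensityModel, state::MyState; recompute_logp=true
+)
+    recompute_logp || return state.logp
+    return LogDensityProblems.logdensity(model.logdensity, state.params)
+end
+
+# Expose the parameter values stored in the state.
+Base.vec(state::MyState) = state.params
+
+# Construct a new state that carries the updated log probability.
+(state::MyState)(logp::Float64) = MyState(state.params, logp)
+```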
""" -function update_trace(trace::NamedTuple, gibbs_state::GibbsState) - for parameter_variable in keys(gibbs_state.mcmc_states) +function update_trace( + trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT} +) where {trace_names,TraceNT,StateNT,SizeNT} + for parameter_variable in fieldnames(StateNT) sub_state = gibbs_state.mcmc_states[parameter_variable] - sub_state_params = Base.vec(sub_state) - unflattened_sub_state_params = unflatten( - sub_state_params, - NamedTuple{(parameter_variable,)}(( - gibbs_state.variable_sizes[parameter_variable], - )), + sub_state_params_values = Base.vec(sub_state) + reshaped_sub_state_params_values = reshape( + sub_state_params_values, gibbs_state.variable_sizes[parameter_variable] ) + unflattened_sub_state_params = NamedTuple{(parameter_variable,)}(( + reshaped_sub_state_params_values, + )) trace = merge(trace, unflattened_sub_state_params) end return trace @@ -321,8 +321,7 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Gibbs{Tsamplingmap}, - args...; + sampler::Gibbs{Tsamplingmap}; initial_params::NamedTuple, kwargs..., ) where {Tsamplingmap} @@ -338,30 +337,27 @@ function AbstractMCMC.step( conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) ) - sub_problem_parameters_values = NamedTuple{(parameter_variable,)}(( - initial_params[parameter_variable], - )) # LogDensityProblems' `logdensity` function expects a single vector of real numbers # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values # and unflatten after the sampling step - flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) + flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable]) + sub_logdensity_model = AbstractMCMC.LogDensityModel( + AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ), + ) sub_state = last( AbstractMCMC.step( rng, - AbstractMCMC.LogDensityModel( - AbstractPPL.condition( - logdensity_model.logdensity, conditioning_variables_values - ), - ), - sub_sampler, - args...; + sub_logdensity_model, + sub_sampler; initial_params=flattened_sub_problem_parameters_values, kwargs..., ), ) - (sub_state, Tuple(size(initial_params[parameter_variable]))) + (sub_state, size(initial_params[parameter_variable])) end mcmc_states_tuple = first.(results) @@ -382,11 +378,12 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, sampler::Gibbs{Tsamplingmap}, - gibbs_state::GibbsState, - args...; + gibbs_state::GibbsState; kwargs..., ) where {Tsamplingmap} - (; trace, mcmc_states, variable_sizes) = gibbs_state + trace = gibbs_state.trace + mcmc_states = gibbs_state.mcmc_states + variable_sizes = gibbs_state.variable_sizes model_parameter_names = fieldnames(Tsamplingmap) mcmc_states = map(model_parameter_names) do parameter_variable @@ -407,7 +404,7 @@ function AbstractMCMC.step( sub_state = (sub_state)(logp) sub_state = last( AbstractMCMC.step( - rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs... + rng, cond_logdensity_model, sub_sampler, sub_state; kwargs... 
 
-"""
-    unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple})
+The `Gibbs` sampler has the same interface as other samplers in `AbstractMCMC` (we don't implement the above state interface for `GibbsState` to keep it simple, but it can be implemented similarly).
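+
+A sketch of what such an implementation could look like is shown below. It is illustrative only and not used elsewhere in this example; it assumes the `GibbsState` fields defined above, and a full implementation would also cache the joint log probability in the state:
+
+```julia
+# Concatenate the values of all variables in the trace into a single vector.
+Base.vec(gibbs_state::GibbsState) = reduce(vcat, vec.(values(gibbs_state.trace)))
+
+# Recompute the joint log probability from the full model. `GibbsState` as defined
+# above caches no joint log probability, so `recompute_logp=false` is unsupported.
+function LogDensityProblems.logdensity(
+    model::AbstractMCMC.LogDensityModel, gibbs_state::GibbsState; recompute_logp=true
+)
+    recompute_logp || error("`GibbsState` does not cache the joint log probability")
+    return LogDensityProblems.logdensity(model.logdensity, Base.vec(gibbs_state))
+end
+```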
 
-Reverse operation of flatten. Reshape the vector into the original arrays using size information.
-"""
-function unflatten(
-    vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names}
-) where {variable_names}
-    result = Dict{Symbol,Array}()
-    start_idx = 1
-    for name in variable_names
-        size = variable_names_and_sizes[name]
-        end_idx = start_idx + prod(size) - 1
-        result[name] = reshape(vec[start_idx:end_idx], size...)
-        start_idx = end_idx + 1
-    end
+The Gibbs sampler operates in two main phases:
 
-    return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names]))
-end
-```
+1. Initialization:
+   - Set up initial states for each conditional probability problem.
 
-Some points worth noting:
+2. Iterative Sampling:
+   For each iteration, the sampler performs a sweep over all conditional probability problems:
 
-1. We are using `NamedTuple` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps. A limitation is that exactly one sampler for each variable is required, which means it is less flexible than Gibbs in `Turing.jl`.
-2. For each conditional probability problem, we need to store the sampler states for each variable group and also the values of all the variables from last iteration.
-3. The first step of the Gibbs sampler is to setup the states for each conditional probability problem.
-4. In the following steps of the Gibbs sampler, it will do a sweep over all the conditional probability problems, and update the sampler states for each problem. In each step of the sweep, it will do the following:
-    - condition on the values of all variables that are not in the current group
-    - recompute the log probability of the current state, because the values of the variables that are not in the current group may have changed
-    - perform a step of the sampler for the conditional probability problem, and update the sampler state
-    - update the `vi` with the new values from the sampler state
+   a. Condition on other variables:
+      - Fix the values of all variables except the current one.
+   b. Update current variable:
+      - Recompute the log probability of the current state, as other variables may have changed:
+        - Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability.
+        - Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability.
+      - Perform a sampling step for the current conditional probability problem:
+        - Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state.
+      - Update the global trace:
+        - Extract parameter values from the new state using `Base.vec(new_sub_state)`.
+        - Incorporate these values into the overall Gibbs state trace.
 
-The `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states.
+This process allows the Gibbs sampler to iteratively update each variable while conditioning on the others, gradually exploring the joint distribution of all variables.
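+
+Schematically, a run of this Gibbs sampler has the following shape (a sketch with hypothetical sampler names and data; the concrete, runnable version follows below):
+
+```julia
+# One sampler per variable; the order of the NamedTuple fixes the sweep order.
+sampler = Gibbs((mu=RandomWalkMH(0.3), tau2=RandomWalkMH(1.0)))
+
+samples = AbstractMCMC.sample(
+    AbstractMCMC.LogDensityModel(HierNormal((x=data,))),
+    sampler,
+    10_000;
+    initial_params=(mu=[0.0], tau2=[1.0]),
+)
+```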
""" -function update_trace(trace::NamedTuple, gibbs_state::GibbsState) - for parameter_variable in keys(gibbs_state.mcmc_states) +function update_trace( + trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT} +) where {trace_names,TraceNT,StateNT,SizeNT} + for parameter_variable in fieldnames(StateNT) sub_state = gibbs_state.mcmc_states[parameter_variable] - sub_state_params = Base.vec(sub_state) - unflattened_sub_state_params = unflatten( - sub_state_params, - NamedTuple{(parameter_variable,)}(( - gibbs_state.variable_sizes[parameter_variable], - )), + sub_state_params_values = Base.vec(sub_state) + reshaped_sub_state_params_values = reshape( + sub_state_params_values, gibbs_state.variable_sizes[parameter_variable] ) + unflattened_sub_state_params = NamedTuple{(parameter_variable,)}(( + reshaped_sub_state_params_values, + )) trace = merge(trace, unflattened_sub_state_params) end return trace @@ -116,8 +60,7 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, - sampler::Gibbs{Tsamplingmap}, - args...; + sampler::Gibbs{Tsamplingmap}; initial_params::NamedTuple, kwargs..., ) where {Tsamplingmap} @@ -133,30 +76,27 @@ function AbstractMCMC.step( conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) ) - sub_problem_parameters_values = NamedTuple{(parameter_variable,)}(( - initial_params[parameter_variable], - )) # LogDensityProblems' `logdensity` function expects a single vector of real numbers # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values # and unflatten after the sampling step - flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) + flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable]) + sub_logdensity_model = AbstractMCMC.LogDensityModel( + AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ), + ) sub_state = last( AbstractMCMC.step( rng, - AbstractMCMC.LogDensityModel( - AbstractPPL.condition( - logdensity_model.logdensity, conditioning_variables_values - ), - ), - sub_sampler, - args...; + sub_logdensity_model, + sub_sampler; initial_params=flattened_sub_problem_parameters_values, kwargs..., ), ) - (sub_state, Tuple(size(initial_params[parameter_variable]))) + (sub_state, size(initial_params[parameter_variable])) end mcmc_states_tuple = first.(results) @@ -177,8 +117,7 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, logdensity_model::AbstractMCMC.LogDensityModel, sampler::Gibbs{Tsamplingmap}, - gibbs_state::GibbsState, - args...; + gibbs_state::GibbsState; kwargs..., ) where {Tsamplingmap} trace = gibbs_state.trace @@ -204,7 +143,7 @@ function AbstractMCMC.step( sub_state = (sub_state)(logp) sub_state = last( AbstractMCMC.step( - rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs... + rng, cond_logdensity_model, sub_sampler, sub_state; kwargs... 
             ),
         )
         trace = update_trace(trace, gibbs_state)
diff --git a/test/gibbs_example/gibbs_test.jl b/test/gibbs_example/gibbs_test.jl
index eedb41d..48e679c 100644
--- a/test/gibbs_example/gibbs_test.jl
+++ b/test/gibbs_example/gibbs_test.jl
@@ -33,7 +33,7 @@ include("hier_normal.jl")
    tau2_mean = only(mean(tau2_samples))
 
     @test mu_mean ≈ mu_true atol = 0.1
-    @test tau2_mean ≈ tau2_true atol = 0.3
+    @test tau2_mean ≈ tau2_true atol = 0.1
 end
 
 # This is too difficult to sample, disable for now
diff --git a/test/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl
index 2f58bf1..2e3e381 100644
--- a/test/gibbs_example/hier_normal.jl
+++ b/test/gibbs_example/hier_normal.jl
@@ -15,7 +15,7 @@ struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <:
 end
 
 # `mu` and `tau2` are length-1 vectors to make
-function log_joint(; mu, tau2, x)
+function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64})
     # mu is the mean
     # tau2 is the variance
     # x is data