Commit: update code and doc

sunxd3 committed Sep 27, 2024
1 parent 3ed5cb3 commit 6fde198
Showing 4 changed files with 80 additions and 161 deletions.
130 changes: 55 additions & 75 deletions docs/src/gibbs.md
@@ -8,7 +8,7 @@ LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state; recompute_logp=true)

This function takes the logdensity model and the state, and returns the log probability of the state.
If `recompute_logp` is `true`, it should recompute the log probability of the state.
Otherwise, it could use the log probability stored in the state.

```julia
Base.vec(state)
```

This function takes the state and returns a vector of the parameter values stored in the state.

```julia
(state::StateType)(logp::Float64)
```

This function takes the state and a log probability value, and returns a new state with the updated log probability.

These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement.
The interface facilitates the implementation of "meta-algorithms" that combine different samplers.
We will demonstrate how it can be used to implement Gibbs sampling in the following sections.
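
As a concrete illustration, here is a minimal sketch of how a sampler package might implement this interface for a hypothetical random-walk sampler state; the `RWState` name and its fields are assumptions for illustration rather than an existing type.

```julia
using AbstractMCMC: AbstractMCMC
using LogDensityProblems: LogDensityProblems

# Hypothetical sampler state: current parameter values plus the cached log density.
struct RWState{T<:AbstractVector{<:Real}}
    params::T
    logp::Float64
end

# Return the cached log probability, or recompute it from the wrapped model on request.
function LogDensityProblems.logdensity(
    logdensity_model::AbstractMCMC.LogDensityModel, state::RWState; recompute_logp=true
)
    return if recompute_logp
        LogDensityProblems.logdensity(logdensity_model.logdensity, state.params)
    else
        state.logp
    end
end

# Expose the parameter values stored in the state as a flat vector.
Base.vec(state::RWState) = state.params

# Return a new state carrying the updated log probability.
(state::RWState)(logp::Float64) = RWState(state.params, logp)
```

With these three methods in place, a meta-algorithm can read parameter values and refresh cached log densities without knowing anything else about `RWState`.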

## Using the `state` Interface for block sampling within Gibbs

@@ -122,7 +124,7 @@ function LogDensityProblems.capabilities(::ConditionedHierNormal)
end
```

### Implementing a Sampler with the `AbstractMCMC` Interface

To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: random walk and static proposal.
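
The full sampler implementations are collapsed in this diff, but a rough sketch of the random-walk sampler's `step` method, building on the hypothetical `RWState` from the sketch above, could look as follows; the `RandomWalkMH` name, the proposal scale field, and the return convention are illustrative assumptions.

```julia
using Random: Random

# Hypothetical random-walk Metropolis-Hastings sampler with a fixed proposal scale.
struct RandomWalkMH <: AbstractMCMC.AbstractSampler
    σ::Float64
end

# Subsequent steps: propose, accept or reject, and return (transition, state).
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::RandomWalkMH,
    state::RWState;
    kwargs...,
)
    params = Base.vec(state)
    proposal = params .+ sampler.σ .* randn(rng, length(params))
    logp_proposal = LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
    logp_current = state.logp  # cached log density of the current state

    # The proposal is symmetric, so the acceptance ratio is just the density ratio.
    new_state = if log(rand(rng)) < logp_proposal - logp_current
        RWState(proposal, logp_proposal)
    else
        state
    end
    # For simplicity, the state doubles as the transition here.
    return new_state, new_state
end
```

The initial `step` method (the one without a `state` argument) is omitted here; it would take `initial_params` and evaluate the log density once.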

@@ -258,15 +260,11 @@ function compute_log_acceptance_ratio(
end
```

We can now proceed to implement a very simple Gibbs sampler.

```julia
"""
Gibbs(sampler_map::NamedTuple)
A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter.
"""
struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
"Maps variables to their samplers."
sampler_map::T
end

@@ -291,16 +289,18 @@ end
Update the trace with the values from the MCMC states of the sub-problems.
"""
function update_trace(
    trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT}
) where {trace_names,TraceNT,StateNT,SizeNT}
    for parameter_variable in fieldnames(StateNT)
        sub_state = gibbs_state.mcmc_states[parameter_variable]
        sub_state_params_values = Base.vec(sub_state)
        reshaped_sub_state_params_values = reshape(
            sub_state_params_values, gibbs_state.variable_sizes[parameter_variable]
        )
        unflattened_sub_state_params = NamedTuple{(parameter_variable,)}((
            reshaped_sub_state_params_values,
        ))
        trace = merge(trace, unflattened_sub_state_params)
    end
    return trace
@@ -321,8 +321,7 @@ end
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::Gibbs{Tsamplingmap};
    initial_params::NamedTuple,
    kwargs...,
) where {Tsamplingmap}
@@ -338,30 +337,27 @@
        conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
            Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
        )

        # LogDensityProblems' `logdensity` function expects a single vector of real numbers
        # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the values
        # for the current variable here and reshape them back after the sampling step
        flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable])

        sub_logdensity_model = AbstractMCMC.LogDensityModel(
            AbstractPPL.condition(
                logdensity_model.logdensity, conditioning_variables_values
            ),
        )
        sub_state = last(
            AbstractMCMC.step(
                rng,
                sub_logdensity_model,
                sub_sampler;
                initial_params=flattened_sub_problem_parameters_values,
                kwargs...,
            ),
        )
        (sub_state, size(initial_params[parameter_variable]))
end

mcmc_states_tuple = first.(results)
@@ -382,11 +378,12 @@
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::Gibbs{Tsamplingmap},
    gibbs_state::GibbsState;
    kwargs...,
) where {Tsamplingmap}
    trace = gibbs_state.trace
    mcmc_states = gibbs_state.mcmc_states
    variable_sizes = gibbs_state.variable_sizes

    model_parameter_names = fieldnames(Tsamplingmap)
    mcmc_states = map(model_parameter_names) do parameter_variable
@@ -407,7 +404,7 @@
sub_state = (sub_state)(logp)
sub_state = last(
AbstractMCMC.step(
rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...
),
)
trace = update_trace(trace, gibbs_state)
@@ -419,53 +416,36 @@ end
end
```
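
To make the bookkeeping in `update_trace` concrete, here is a small self-contained illustration of the reshape-and-merge step with made-up values:

```julia
# Flat values for a variable `a` that is logically a 2×2 matrix.
flat_values = [1.0, 2.0, 3.0, 4.0]
variable_size = (2, 2)

reshaped = reshape(flat_values, variable_size)   # 2×2 Matrix (column-major)
new_entry = NamedTuple{(:a,)}((reshaped,))       # (a = [1.0 3.0; 2.0 4.0],)

trace = (a=zeros(2, 2), b=[0.0])
trace = merge(trace, new_entry)                  # `a` is replaced, `b` is kept
```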

We use a `NamedTuple` to store the mapping between variables and samplers. Its order determines the order of the Gibbs sweeps. A limitation is that exactly one sampler is required for each variable, which makes this less flexible than the Gibbs implementation in `Turing.jl`.

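For example, assuming the hierarchical Normal model's variables are `mu` and `tau2` and reusing the hypothetical `RandomWalkMH` sampler sketched earlier (the proposal scales are made up), the sampler map could be constructed as:

```julia
sampler_map = (
    mu=RandomWalkMH(0.3),    # updates the mean, conditioned on tau2
    tau2=RandomWalkMH(0.2),  # updates the variance, conditioned on mu
)
gibbs_sampler = Gibbs(sampler_map)
```

Each Gibbs sweep would then update `mu` first and `tau2` second.
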
We use `AbstractPPL.condition` to divide the full model into smaller conditional probability problems. Each conditional probability problem corresponds to a sampler and its associated state.

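As a rough sketch of what such a conditioning overload might look like for the hierarchical Normal example (the exact field names of `HierNormal` and `ConditionedHierNormal` are defined earlier in the document and are assumed here):

```julia
# Assumed sketch: wrap the observed data together with the values we condition on.
function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
    return ConditionedHierNormal(hn.data, conditioned_values)
end
```
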
"""
unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple})
The `Gibbs` sampler has the same interface as other samplers in `AbstractMCMC` (we don't implement the above state interface for `GibbsState` to keep it simple, but it can be implemented similarly).

The Gibbs sampler operates in two main phases:

1. Initialization:
   - Set up initial states for each conditional probability problem.

2. Iterative Sampling:
   For each iteration, the sampler performs a sweep over all conditional probability problems:

   a. Condition on other variables:
      - Fix the values of all variables except the current one.
   b. Update current variable:
      - Recompute the log probability of the current state, as other variables may have changed:
        - Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability.
        - Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability.
      - Perform a sampling step for the current conditional probability problem:
        - Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state.
      - Update the global trace:
        - Extract parameter values from the new state using `Base.vec(new_sub_state)`.
        - Incorporate these values into the overall Gibbs state trace.

This process allows the Gibbs sampler to iteratively update each variable while conditioning on the others, gradually exploring the joint distribution of all variables.

Now we can use the Gibbs sampler to sample from the hierarchical Normal model.

First we generate some data,
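
The data-generation and sampling calls are collapsed below, but an end-to-end sketch might look like the following; the `HierNormal` constructor, variable names, sampler choices, and all numbers are assumptions for illustration, not the exact code from this commit.

```julia
using Random, Distributions

# Simulate observations from a Normal with known mean and variance.
rng = Random.MersenneTwister(42)
true_mu, true_tau2 = 3.0, 4.0
x = rand(rng, Normal(true_mu, sqrt(true_tau2)), 1_000)

# Assumed model wrapper holding the observations (defined earlier in the document).
hier_normal_model = HierNormal((x=x,))

samples = AbstractMCMC.sample(
    rng,
    AbstractMCMC.LogDensityModel(hier_normal_model),
    Gibbs((mu=RandomWalkMH(0.3), tau2=RandomWalkMH(0.2))),
    10_000;
    initial_params=(mu=[0.0], tau2=[1.0]),
)
```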

107 changes: 23 additions & 84 deletions test/gibbs_example/gibbs.jl
@@ -1,14 +1,9 @@
using AbstractMCMC: AbstractMCMC
using AbstractPPL: AbstractPPL
using MCMCChains: Chains
using Random

"""
Gibbs(sampler_map::NamedTuple)
A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter.
"""
struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
"Maps variables to their samplers."
sampler_map::T
end

@@ -28,74 +23,23 @@ struct GibbsTransition{ValuesNT<:NamedTuple}
values::ValuesNT
end

"""
flatten(trace::NamedTuple)
Flatten all the values in the trace into a single vector. Variable names information is discarded.
# Examples
```jldoctest; setup = :(using AbstractMCMC: flatten)
julia> flatten((a=ones(2), b=ones(2, 2)))
6-element Vector{Float64}:
1.0
1.0
1.0
1.0
1.0
1.0
```
"""
function flatten(trace::NamedTuple)
return reduce(vcat, vec.(values(trace)))
end

"""
unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple})
Reverse operation of flatten. Reshape the vector into the original arrays using size information.
# Examples
```jldoctest; setup = :(using AbstractMCMC: unflatten)
julia> unflatten([1,2,3,4,5], (a=(2,), b=(3,)))
(a = [1, 2], b = [3, 4, 5])
julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,)))
(x = [1.0 3.0; 2.0 4.0], y = [5.0, 6.0])
```
"""
function unflatten(
vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names}
) where {variable_names}
result = Dict{Symbol,Array}()
start_idx = 1
for name in variable_names
size = variable_names_and_sizes[name]
end_idx = start_idx + prod(size) - 1
result[name] = reshape(vec[start_idx:end_idx], size...)
start_idx = end_idx + 1
end

return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names]))
end

"""
update_trace(trace::NamedTuple, gibbs_state::GibbsState)
Update the trace with the values from the MCMC states of the sub-problems.
"""
function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
for parameter_variable in keys(gibbs_state.mcmc_states)
function update_trace(
trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT}
) where {trace_names,TraceNT,StateNT,SizeNT}
for parameter_variable in fieldnames(StateNT)
sub_state = gibbs_state.mcmc_states[parameter_variable]
sub_state_params = Base.vec(sub_state)
unflattened_sub_state_params = unflatten(
sub_state_params,
NamedTuple{(parameter_variable,)}((
gibbs_state.variable_sizes[parameter_variable],
)),
sub_state_params_values = Base.vec(sub_state)
reshaped_sub_state_params_values = reshape(
sub_state_params_values, gibbs_state.variable_sizes[parameter_variable]
)
unflattened_sub_state_params = NamedTuple{(parameter_variable,)}((
reshaped_sub_state_params_values,
))
trace = merge(trace, unflattened_sub_state_params)
end
return trace
@@ -116,8 +60,7 @@ end
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::Gibbs{Tsamplingmap};
    initial_params::NamedTuple,
    kwargs...,
) where {Tsamplingmap}
@@ -133,30 +76,27 @@
        conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
            Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
        )

        # LogDensityProblems' `logdensity` function expects a single vector of real numbers
        # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the values
        # for the current variable here and reshape them back after the sampling step
        flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable])

        sub_logdensity_model = AbstractMCMC.LogDensityModel(
            AbstractPPL.condition(
                logdensity_model.logdensity, conditioning_variables_values
            ),
        )
        sub_state = last(
            AbstractMCMC.step(
                rng,
                sub_logdensity_model,
                sub_sampler;
                initial_params=flattened_sub_problem_parameters_values,
                kwargs...,
            ),
        )
        (sub_state, size(initial_params[parameter_variable]))
end

mcmc_states_tuple = first.(results)
@@ -177,8 +117,7 @@
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::Gibbs{Tsamplingmap},
    gibbs_state::GibbsState;
    kwargs...,
) where {Tsamplingmap}
    trace = gibbs_state.trace
@@ -204,7 +143,7 @@
sub_state = (sub_state)(logp)
sub_state = last(
AbstractMCMC.step(
rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...
),
)
trace = update_trace(trace, gibbs_state)