Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Sep 20, 2024
1 parent 62a2332 commit 6132f0c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 44 deletions.
65 changes: 25 additions & 40 deletions test/gibbs_example/gibbs.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,28 @@
using AbstractMCMC, AbstractPPL
using BangBang.ConstructorBase: ConstructorBase

"""
Gibbs(sampler_map::NamedTuple)
An interface for block sampling in Markov Chain Monte Carlo (MCMC).
Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems.
It allows different sampling methods to be applied to different parameters.
A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter.
"""
struct Gibbs{NT<:NamedTuple} <: AbstractMCMC.AbstractSampler
sampler_map::NT
struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
sampler_map::T
end

struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple}
"""
Contains the values of all parameters up to the last iteration.
"""
"Contains the values of all parameters up to the last iteration."
trace::TraceNT

"""
Maps parameters to their sampler-specific MCMC states.
"""
"Maps parameters to their sampler-specific MCMC states."
mcmc_states::StateNT

"""
Maps parameters to their sizes.
"""
"Maps parameters to their sizes."
variable_sizes::SizeNT
end

struct GibbsTransition{ValuesNT<:NamedTuple}
"""
Realizations of the parameters, this is considered a "sample" in the MCMC chain.
"""
"Realizations of the parameters, this is considered a \"sample\" in the MCMC chain."
values::ValuesNT
end

Expand Down Expand Up @@ -95,7 +87,7 @@ Update the trace with the values from the MCMC states of the sub-problems.
function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
for parameter_variable in keys(gibbs_state.mcmc_states)
sub_state = gibbs_state.mcmc_states[parameter_variable]
sub_state_params = vec(sub_state)
sub_state_params = Base.vec(sub_state)
unflattened_sub_state_params = unflatten(
sub_state_params,
NamedTuple{(parameter_variable,)}((
Expand All @@ -115,21 +107,19 @@ function AbstractMCMC.step(
initial_params::NamedTuple,
kwargs...,
)
if Set(keys(initial_params)) != Set(sampler.parameter_names)
if Set(keys(initial_params)) != Set(keys(sampler.sampler_map))
throw(
ArgumentError(
"initial_params must contain all parameters in the model, expected $(sampler.parameter_names), got $(keys(initial_params))",
"initial_params must contain all parameters in the model, expected $(keys(sampler.sampler_map)), got $(keys(initial_params))",
),
)
end

mcmc_states = Dict{Symbol,Any}()
variable_sizes = Dict{Symbol,Tuple}()
for parameter_variable in sampler.parameter_names
mcmc_states, variable_sizes = map(keys(sampler.sampler_map)) do parameter_variable
sub_sampler = sampler.sampler_map[parameter_variable]

variables_to_be_conditioned_on = setdiff(
sampler.parameter_names, (parameter_variable,)
keys(sampler.sampler_map), (parameter_variable,)
)
conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
Expand All @@ -141,7 +131,6 @@ function AbstractMCMC.step(
# LogDensityProblems' `logdensity` function expects a single vector of real numbers
# `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values
# and unflatten after the sampling step
variable_sizes[parameter_variable] = Tuple(size(initial_params[parameter_variable]))
flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values)

sub_state = last(
Expand All @@ -158,11 +147,13 @@ function AbstractMCMC.step(
kwargs...,
),
)
mcmc_states[parameter_variable] = sub_state
(sub_state, Tuple(size(initial_params[parameter_variable])))
end

gibbs_state = GibbsState(
initial_params, NamedTuple(mcmc_states), NamedTuple(variable_sizes)
initial_params,
NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states),
NamedTuple{Tuple(keys(sampler.sampler_map))}(variable_sizes),
)
trace = update_trace(NamedTuple(), gibbs_state)
return GibbsTransition(trace), gibbs_state
Expand All @@ -176,14 +167,9 @@ function AbstractMCMC.step(
args...;
kwargs...,
)
trace = gibbs_state.trace
mcmc_states = gibbs_state.mcmc_states
variable_sizes = gibbs_state.variable_sizes
(; trace, mcmc_states, variable_sizes) = gibbs_state

mcmc_states_dict = Dict(
keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)]
)
for parameter_variable in sampler.parameter_names
mcmc_states = map(keys(sampler.sampler_map)) do parameter_variable
sub_sampler = sampler.sampler_map[parameter_variable]
sub_state = mcmc_states[parameter_variable]
variables_to_be_conditioned_on = setdiff(
Expand All @@ -196,7 +182,8 @@ function AbstractMCMC.step(
logdensity_model.logdensity, conditioning_variables_values
)

_, sub_state = AbstractMCMC.logdensity_and_state(cond_logdensity, sub_state)
logp = LogDensityProblems.logdensity_and_state(cond_logdensity, sub_state)
sub_state = constructorof(typeof(sub_state))(; logp=logp)
sub_state = last(
AbstractMCMC.step(
rng,
Expand All @@ -207,12 +194,10 @@ function AbstractMCMC.step(
kwargs...,
),
)
mcmc_states_dict[parameter_variable] = sub_state
trace = update_trace(trace, gibbs_state)
sub_state
end
mcmc_states = NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states)

mcmc_states = NamedTuple{Tuple(keys(mcmc_states_dict))}(
Tuple([mcmc_states_dict[k] for k in keys(mcmc_states_dict)])
)
return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes)
end
8 changes: 4 additions & 4 deletions test/gibbs_example/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ function Base.vec(state::MHState)
return state.params
end

struct RandomWalkMH <: AbstractMCMC.AbstractSampler
σ::Float64
struct RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
σ::T
end

struct IndependentMH <: AbstractMCMC.AbstractSampler
proposal_dist::Distributions.Distribution
struct IndependentMH{T} <: AbstractMCMC.AbstractSampler
proposal_dist::T
end

function AbstractMCMC.step(
Expand Down

0 comments on commit 6132f0c

Please sign in to comment.