diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl index 817eb89..44f2481 100644 --- a/test/gibbs_example/gibbs.jl +++ b/test/gibbs_example/gibbs.jl @@ -1,36 +1,28 @@ +using AbstractMCMC, AbstractPPL +using BangBang.ConstructorBase: ConstructorBase + """ Gibbs(sampler_map::NamedTuple) -An interface for block sampling in Markov Chain Monte Carlo (MCMC). - -Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems. -It allows different sampling methods to be applied to different parameters. +A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter. """ -struct Gibbs{NT<:NamedTuple} <: AbstractMCMC.AbstractSampler - sampler_map::NT +struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler + sampler_map::T end struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple} - """ - Contains the values of all parameters up to the last iteration. - """ + "Contains the values of all parameters up to the last iteration." trace::TraceNT - """ - Maps parameters to their sampler-specific MCMC states. - """ + "Maps parameters to their sampler-specific MCMC states." mcmc_states::StateNT - """ - Maps parameters to their sizes. - """ + "Maps parameters to their sizes." variable_sizes::SizeNT end struct GibbsTransition{ValuesNT<:NamedTuple} - """ - Realizations of the parameters, this is considered a "sample" in the MCMC chain. - """ + "Realizations of the parameters, this is considered a \"sample\" in the MCMC chain." values::ValuesNT end @@ -95,7 +87,7 @@ Update the trace with the values from the MCMC states of the sub-problems. function update_trace(trace::NamedTuple, gibbs_state::GibbsState) for parameter_variable in keys(gibbs_state.mcmc_states) sub_state = gibbs_state.mcmc_states[parameter_variable] - sub_state_params = vec(sub_state) + sub_state_params = Base.vec(sub_state) unflattened_sub_state_params = unflatten( sub_state_params, NamedTuple{(parameter_variable,)}(( @@ -115,21 +107,19 @@ function AbstractMCMC.step( initial_params::NamedTuple, kwargs..., ) - if Set(keys(initial_params)) != Set(sampler.parameter_names) + if Set(keys(initial_params)) != Set(keys(sampler.sampler_map)) throw( ArgumentError( - "initial_params must contain all parameters in the model, expected $(sampler.parameter_names), got $(keys(initial_params))", + "initial_params must contain all parameters in the model, expected $(keys(sampler.sampler_map)), got $(keys(initial_params))", ), ) end - mcmc_states = Dict{Symbol,Any}() - variable_sizes = Dict{Symbol,Tuple}() - for parameter_variable in sampler.parameter_names + mcmc_states, variable_sizes = map(keys(sampler.sampler_map)) do parameter_variable sub_sampler = sampler.sampler_map[parameter_variable] variables_to_be_conditioned_on = setdiff( - sampler.parameter_names, (parameter_variable,) + keys(sampler.sampler_map), (parameter_variable,) ) conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) @@ -141,7 +131,6 @@ function AbstractMCMC.step( # LogDensityProblems' `logdensity` function expects a single vector of real numbers # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values # and unflatten after the sampling step - variable_sizes[parameter_variable] = Tuple(size(initial_params[parameter_variable])) flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values) sub_state = last( @@ -158,11 +147,13 @@ function AbstractMCMC.step( kwargs..., ), ) - mcmc_states[parameter_variable] = sub_state + (sub_state, Tuple(size(initial_params[parameter_variable]))) end gibbs_state = GibbsState( - initial_params, NamedTuple(mcmc_states), NamedTuple(variable_sizes) + initial_params, + NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states), + NamedTuple{Tuple(keys(sampler.sampler_map))}(variable_sizes), ) trace = update_trace(NamedTuple(), gibbs_state) return GibbsTransition(trace), gibbs_state @@ -176,14 +167,9 @@ function AbstractMCMC.step( args...; kwargs..., ) - trace = gibbs_state.trace - mcmc_states = gibbs_state.mcmc_states - variable_sizes = gibbs_state.variable_sizes + (; trace, mcmc_states, variable_sizes) = gibbs_state - mcmc_states_dict = Dict( - keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)] - ) - for parameter_variable in sampler.parameter_names + mcmc_states = map(keys(sampler.sampler_map)) do parameter_variable sub_sampler = sampler.sampler_map[parameter_variable] sub_state = mcmc_states[parameter_variable] variables_to_be_conditioned_on = setdiff( @@ -196,7 +182,8 @@ function AbstractMCMC.step( logdensity_model.logdensity, conditioning_variables_values ) - _, sub_state = AbstractMCMC.logdensity_and_state(cond_logdensity, sub_state) + logp = LogDensityProblems.logdensity_and_state(cond_logdensity, sub_state) + sub_state = constructorof(typeof(sub_state))(; logp=logp) sub_state = last( AbstractMCMC.step( rng, @@ -207,12 +194,10 @@ function AbstractMCMC.step( kwargs..., ), ) - mcmc_states_dict[parameter_variable] = sub_state trace = update_trace(trace, gibbs_state) + sub_state end + mcmc_states = NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states) - mcmc_states = NamedTuple{Tuple(keys(mcmc_states_dict))}( - Tuple([mcmc_states_dict[k] for k in keys(mcmc_states_dict)]) - ) return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes) end diff --git a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl index fae0361..24f1522 100644 --- a/test/gibbs_example/mh.jl +++ b/test/gibbs_example/mh.jl @@ -24,12 +24,12 @@ function Base.vec(state::MHState) return state.params end -struct RandomWalkMH <: AbstractMCMC.AbstractSampler - σ::Float64 +struct RandomWalkMH{T} <: AbstractMCMC.AbstractSampler + σ::T end -struct IndependentMH <: AbstractMCMC.AbstractSampler - proposal_dist::Distributions.Distribution +struct IndependentMH{T} <: AbstractMCMC.AbstractSampler + proposal_dist::T end function AbstractMCMC.step(