From 600d36cb556ccfc906833ff6363fa69caa7e960d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Oct 2024 09:28:34 -0400 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/api.md | 57 +++++++++++---------------------------------- src/AbstractMCMC.jl | 32 +++++++++---------------- 2 files changed, 25 insertions(+), 64 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 648a87b8..f3ce4271 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -112,8 +112,8 @@ AbstractMCMC.chainsstack To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: ```@docs -AbstractMCMC.realize -AbstractMCMC.realize!! +AbstractMCMC.getparams +AbstractMCMC.setparams!! ``` and optionally ```@docs @@ -125,7 +125,7 @@ These methods can also be useful for implementing samplers which wraps some inne In a `MixtureSampler` we need two things: - `components`: collection of samplers. -- `weights`: collection of weights representing the probability of chosing the corresponding sampler. +- `weights`: collection of weights representing the probability of choosing the corresponding sampler. ```julia struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler @@ -136,7 +136,6 @@ end To implement the state, we need to keep track of a couple of things: - `index`: the index of the sampler used in this `step`. -- `transition`: the transition resulting from this `step`. - `states`: the current states of _all_ the components. Two aspects of this might seem a bit strange: 1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. @@ -146,11 +145,9 @@ The reason for (1) is that lots of samplers keep track of more than just the pre For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`. -Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`. ```julia -struct MixtureState{T,S} +struct MixtureState{S} index::Int - transition::T states::S end ``` @@ -162,15 +159,16 @@ X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1}) \end{aligned} ``` where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler. -[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. +[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code: ```julia # Update the corresponding state, i.e. `state.states[i]`, using # the state and transition from the previous iteration. -state_current = AbstractMCMC.updatestate!!( - state.states[i], state.states[i_prev], state.transition +state_current = AbstractMCMC.setparams!!( + state.states[i], + AbstractMCMC.getparams(state.states[i_prev]), ) # Take a `step` for this sampler using the updated state. @@ -191,8 +189,9 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # Update the corresponding state, i.e. `state.states[i]`, using # the state and transition from the previous iteration. i_prev = state.index - state_current = AbstractMCMC.updatestate!!( - model, state.states[i], state.states[i_prev], state.transition + state_current = AbstractMCMC.setparams!!( + state.states[i], + AbstractMCMC.getparams(state.states[i_prev]), ) # Take a `step` for this sampler using the updated state. @@ -217,7 +216,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt end # Create the new `MixtureState`. - state_new = MixtureState(i, transition, states_new) + state_new = MixtureState(i, states_new) return transition, state_new end @@ -239,20 +238,14 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # Extract states. states = map(last, transitions_and_states) # Create new `MixtureState`. - state = MixtureState(i, transition, states) + state = MixtureState(i, states) return transition, state end ``` -Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `realize` and `realize!!`: +Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `getparams` and `setparams!!`. -```julia -function AbstractMCMC.updatestate!!(model, ::AdvancedMH.Transition, state_prev::AdvancedMH.Transition) - # Let's `deepcopy` just to be certain. - return deepcopy(state_prev) -end -``` To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do @@ -263,25 +256,3 @@ while ... transition, state = AbstractMCMC.step(rng, model, sampler, state) end ``` - -As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`: - -```julia -struct ManyModels{M} <: AbstractMCMC.AbstractModel - models::M -end - -Base.getindex(model::ManyModels, I...) = model.models[I...] -``` - -Then the above `step` would just extract the `model` corresponding to the current sampler: - -```julia -# Take a `step` for this sampler using the updated state. -transition, state_current = AbstractMCMC.step( - rng, model[i], sampler_current, state_current; - kwargs... -) -``` - -This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`. diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 07960440..687d8f83 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,35 +80,25 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr struct MCMCSerial <: AbstractMCMCEnsemble end """ - updatestate!!(model, state, transition_prev[, state_prev]) + getparams(state[; kwargs...]) -Return new instance of `state` using information from `model`, `transition_prev` and, optionally, `state_prev`. - -Defaults to `realize!!(state, realize(transition_prev))`. +Retrieve the values of parameters from the sampler's `state` as a `Vector{<:Real}`. """ -function updatestate!!(model, state, transition_prev, state_prev) - return updatestate!!(state, transition_prev) -end -updatestate!!(model, state, transition) = realize!!(state, realize(transition)) +function getparams end """ - realize!!(state, realization) - -Update the realization of the `state` with `realization` and return it. + setparams!!(state, params) -If `state` can be updated in-place, it is expected that this function returns `state` with updated -realize. Otherwise a new `state` object with the new `realization` is returned. -""" -function realize!! end +Set the values of parameters in the sampler's `state` from a `Vector{<:Real}`. -""" - realize(transition) +This function should follow the `BangBang` interface: mutate `state` in-place if possible and +return the mutated `state`. Otherwise, it should return a new `state` containing the updated parameters. -Return the realization of the random variables present in `transition`. +Although not enforced, it should hold that `setparams!!(state, getparams(state)) == state`. In another +word, the sampler should implement a consistent transformation between its internal representation +and the vector representation of the parameter values. """ -function realize end - - +function setparams!! end include("samplingstats.jl") include("logging.jl") include("interface.jl")