Skip to content

Commit

Permalink
Rework Gibbs constructors (#2456)
Browse files Browse the repository at this point in the history
* Rework Gibbs constructors, and remove the dead test/experimental/gibbs.jl

* Update HISTORY.md

* Clarify docstring

* Remove unnecessary _maybecollect in gibbs.jl

* Fix a bug

* Fix more Gibbs constructors in tests

* Improve HISTORY.md note

Co-authored-by: Penelope Yong <[email protected]>

* Apply proposals from code review

* Add type bounds to Gibbs type parameters

* Style improvements to gibbs.jl

* Fix method ambiguity

* Modify type signature of Gibbs

---------

Co-authored-by: Penelope Yong <[email protected]>
  • Loading branch information
mhauru and penelopeysm authored Jan 8, 2025
1 parent 0c3d3d0 commit 7d6f8ed
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 355 deletions.
6 changes: 2 additions & 4 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

0.36.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely.

The new Gibbs sampler supports the same user-facing interface as the old one. However, given
that the internals of it having been completely rewritten in a very different manner, there
may be accidental breakage that we haven't anticipated. Please report any you find.
The new Gibbs sampler currently supports the same user-facing interface as the old one, but the old constructors have been deprecated, and will be removed in the future. Also, given that the internals have been completely rewritten in a very different manner, there may be accidental breakage that we haven't anticipated. Please report any you find.

`GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking.

The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable.
The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables thereof to samplers, e.g. `Gibbs(x=>HMC(), y=>MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable.

Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`.

Expand Down
71 changes: 42 additions & 29 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,40 @@ function set_selector(x::RepeatSampler)
end
set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0))

to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)]
# Any other value is assumed to be an iterable of VarNames and Symbols.
to_varname_list(t) = collect(map(VarName, t))

"""
Gibbs
A type representing a Gibbs sampler.
# Constructors
`Gibbs` needs to be given a set of pairs of variable names and samplers. Instead of a single
variable name per sampler, one can also give an iterable of variables, all of which are
sampled by the same component sampler.
Each variable name can be given as either a `Symbol` or a `VarName`.
Some examples of valid constructors are:
```julia
Gibbs(:x => NUTS(), :y => MH())
Gibbs(@varname(x) => NUTS(), @varname(y) => MH())
Gibbs((@varname(x), :y) => NUTS(), :z => MH())
```
Currently only variable names without indexing are supported, so for instance
`Gibbs(@varname(x[1]) => NUTS())` does not work. This will hopefully change in the future.
# Fields
$(TYPEDFIELDS)
"""
struct Gibbs{V,A} <: InferenceAlgorithm
struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:
InferenceAlgorithm
# TODO(mhauru) Revisit whether A should have a fixed element type once
# InferenceAlgorithm/Sampler types have been cleaned up.
"varnames representing variables for each sampler"
varnames::V
"samplers for each entry in `varnames`"
Expand All @@ -310,40 +335,30 @@ struct Gibbs{V,A} <: InferenceAlgorithm
if length(varnames) != length(samplers)
throw(ArgumentError("Number of varnames and samplers must match."))
end

for spl in samplers
if !isgibbscomponent(spl)
msg = "All samplers must be valid Gibbs components, $(spl) is not."
throw(ArgumentError(msg))
end
end
return new{typeof(varnames),typeof(samplers)}(varnames, samplers)
end
end

to_varname(vn::VarName) = vn
to_varname(s::Symbol) = VarName{s}()
# Any other value is assumed to be an iterable.
to_varname(t) = map(to_varname, collect(t))

# NamedTuple
Gibbs(; algs...) = Gibbs(NamedTuple(algs))
function Gibbs(algs::NamedTuple)
return Gibbs(map(to_varname, keys(algs)), map(set_selector drop_space, values(algs)))
# Ensure that samplers have the same selector, and that varnames are lists of
# VarNames.
samplers = tuple(map(set_selector drop_space, samplers)...)
varnames = tuple(map(to_varname_list, varnames)...)
return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers)
end
end

# AbstractDict
function Gibbs(algs::AbstractDict)
return Gibbs(
map(to_varname, collect(keys(algs))), map(set_selector drop_space, values(algs))
)
end
function Gibbs(algs::Pair...)
return Gibbs(map(to_varname first, algs), map(set_selector drop_space last, algs))
return Gibbs(map(first, algs), map(last, algs))
end

# The below two constructors only provide backwards compatibility with the constructor of
# the old Gibbs sampler. They are deprecated and will be removed in the future.
function Gibbs(algs::InferenceAlgorithm...)
function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...)
algs = [alg1, other_algs...]
varnames = map(algs) do alg
space = getspace(alg)
if (space isa VarName)
Expand All @@ -365,7 +380,11 @@ function Gibbs(algs::InferenceAlgorithm...)
return Gibbs(varnames, map(set_selector drop_space, algs))
end

function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...)
function Gibbs(
alg_with_iters1::Tuple{<:InferenceAlgorithm,Int},
other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...,
)
algs_with_iters = [alg_with_iters1, other_algs_with_iters...]
algs = Iterators.map(first, algs_with_iters)
iters = Iterators.map(last, algs_with_iters)
algs_duplicated = Iterators.flatten((
Expand All @@ -384,11 +403,6 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
states::S
end

_maybevec(x) = vec(x) # assume it's iterable
_maybevec(x::Tuple) = [x...]
_maybevec(x::VarName) = [x]
_maybevec(x::Symbol) = [x]

varinfo(state::GibbsState) = state.vi

function DynamicPPL.initialstep(
Expand All @@ -412,7 +426,6 @@ function DynamicPPL.initialstep(
# Initialise each component sampler in turn, collect all their states.
states = []
for (varnames_local, sampler_local) in zip(varnames, samplers)
varnames_local = _maybevec(varnames_local)
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
Expand Down Expand Up @@ -463,7 +476,7 @@ function AbstractMCMC.step(
# Take the inner step.
sampler_local = samplers[index]
state_local = states[index]
varnames_local = _maybevec(varnames[index])
varnames_local = varnames[index]
vi, new_state_local = gibbs_step_inner(
rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
)
Expand Down
6 changes: 3 additions & 3 deletions test/dynamicppl/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const gdemo_default = gdemo_d()

smc = SMC()
pg = PG(10)
gibbs = Gibbs(; p=HMC(0.2, 3), x=PG(10))
gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))

chn_s = sample(testbb(obs), smc, 1000)
chn_p = sample(testbb(obs), pg, 2000)
Expand All @@ -81,7 +81,7 @@ const gdemo_default = gdemo_d()
return s, m
end

gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8))
gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8))
chain = sample(fggibbstest(xs), gibbs, 2)
end
@testset "new grammar" begin
Expand Down Expand Up @@ -177,7 +177,7 @@ const gdemo_default = gdemo_d()
end

@testset "sample" begin
alg = Gibbs(; m=HMC(0.2, 3), s=PG(10))
alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10))
chn = sample(gdemo_default, alg, 1000)
end

Expand Down
Loading

0 comments on commit 7d6f8ed

Please sign in to comment.