Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework Gibbs constructors #2456

Merged
merged 12 commits into from
Jan 8, 2025
6 changes: 2 additions & 4 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

0.36.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely.

The new Gibbs sampler supports the same user-facing interface as the old one. However, given
that the internals of it having been completely rewritten in a very different manner, there
may be accidental breakage that we haven't anticipated. Please report any you find.
The new Gibbs sampler currently supports the same user-facing interface as the old one, but the old constructors have been deprecated, and will be removed in the future. Also, given that the internals have been completely rewritten in a very different manner, there may be accidental breakage that we haven't anticipated. Please report any you find.

`GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking.

The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable.
The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables thereof to samplers, e.g. `Gibbs(x=>HMC(), y=>MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable.

Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`.

Expand Down
71 changes: 42 additions & 29 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,40 @@
end
set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0))

to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)]
# Any other value is assumed to be an iterable of VarNames and Symbols.
to_varname_list(t) = collect(map(VarName, t))

"""
Gibbs

A type representing a Gibbs sampler.

# Constructors

`Gibbs` needs to be given a set of pairs of variable names and samplers. Instead of a single
variable name per sampler, one can also give an iterable of variables, all of which are
sampled by the same component sampler.

Each variable name can be given as either a `Symbol` or a `VarName`.

Some examples of valid constructors are:
```julia
Gibbs(:x => NUTS(), :y => MH())
Gibbs(@varname(x) => NUTS(), @varname(y) => MH())
Gibbs((@varname(x), :y) => NUTS(), :z => MH())
```

Currently only variable names without indexing are supported, so for instance
`Gibbs(@varname(x[1]) => NUTS())` does not work. This will hopefully change in the future.

# Fields
$(TYPEDFIELDS)
"""
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
struct Gibbs{V,A} <: InferenceAlgorithm
struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:
InferenceAlgorithm
# TODO(mhauru) Revisit whether A should have a fixed element type once
# InferenceAlgorithm/Sampler types have been cleaned up.
"varnames representing variables for each sampler"
varnames::V
"samplers for each entry in `varnames`"
Expand All @@ -310,40 +335,30 @@
if length(varnames) != length(samplers)
throw(ArgumentError("Number of varnames and samplers must match."))
end

for spl in samplers
if !isgibbscomponent(spl)
msg = "All samplers must be valid Gibbs components, $(spl) is not."
throw(ArgumentError(msg))
end
end
return new{typeof(varnames),typeof(samplers)}(varnames, samplers)
end
end

to_varname(vn::VarName) = vn
to_varname(s::Symbol) = VarName{s}()
# Any other value is assumed to be an iterable.
to_varname(t) = map(to_varname, collect(t))

# NamedTuple
Gibbs(; algs...) = Gibbs(NamedTuple(algs))
function Gibbs(algs::NamedTuple)
return Gibbs(map(to_varname, keys(algs)), map(set_selector ∘ drop_space, values(algs)))
# Ensure that samplers have the same selector, and that varnames are lists of
# VarNames.
samplers = tuple(map(set_selector ∘ drop_space, samplers)...)
varnames = tuple(map(to_varname_list, varnames)...)
return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers)
end
end

# AbstractDict
function Gibbs(algs::AbstractDict)
return Gibbs(
map(to_varname, collect(keys(algs))), map(set_selector ∘ drop_space, values(algs))
)
end
function Gibbs(algs::Pair...)
return Gibbs(map(to_varname ∘ first, algs), map(set_selector ∘ drop_space ∘ last, algs))
return Gibbs(map(first, algs), map(last, algs))
end

# The below two constructors only provide backwards compatibility with the constructor of
# the old Gibbs sampler. They are deprecated and will be removed in the future.
function Gibbs(algs::InferenceAlgorithm...)
function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...)
algs = [alg1, other_algs...]

Check warning on line 361 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L360-L361

Added lines #L360 - L361 were not covered by tests
varnames = map(algs) do alg
space = getspace(alg)
if (space isa VarName)
Expand All @@ -365,7 +380,11 @@
return Gibbs(varnames, map(set_selector ∘ drop_space, algs))
end

function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...)
function Gibbs(

Check warning on line 383 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L383

Added line #L383 was not covered by tests
alg_with_iters1::Tuple{<:InferenceAlgorithm,Int},
other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...,
)
algs_with_iters = [alg_with_iters1, other_algs_with_iters...]

Check warning on line 387 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L387

Added line #L387 was not covered by tests
algs = Iterators.map(first, algs_with_iters)
iters = Iterators.map(last, algs_with_iters)
algs_duplicated = Iterators.flatten((
Expand All @@ -384,11 +403,6 @@
states::S
end

_maybevec(x) = vec(x) # assume it's iterable
_maybevec(x::Tuple) = [x...]
_maybevec(x::VarName) = [x]
_maybevec(x::Symbol) = [x]

varinfo(state::GibbsState) = state.vi

function DynamicPPL.initialstep(
Expand All @@ -412,7 +426,6 @@
# Initialise each component sampler in turn, collect all their states.
states = []
for (varnames_local, sampler_local) in zip(varnames, samplers)
varnames_local = _maybevec(varnames_local)
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
Expand Down Expand Up @@ -463,7 +476,7 @@
# Take the inner step.
sampler_local = samplers[index]
state_local = states[index]
varnames_local = _maybevec(varnames[index])
varnames_local = varnames[index]
vi, new_state_local = gibbs_step_inner(
rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
)
Expand Down
6 changes: 3 additions & 3 deletions test/dynamicppl/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const gdemo_default = gdemo_d()

smc = SMC()
pg = PG(10)
gibbs = Gibbs(; p=HMC(0.2, 3), x=PG(10))
gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))

chn_s = sample(testbb(obs), smc, 1000)
chn_p = sample(testbb(obs), pg, 2000)
Expand All @@ -81,7 +81,7 @@ const gdemo_default = gdemo_d()
return s, m
end

gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8))
gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8))
chain = sample(fggibbstest(xs), gibbs, 2)
end
@testset "new grammar" begin
Expand Down Expand Up @@ -177,7 +177,7 @@ const gdemo_default = gdemo_d()
end

@testset "sample" begin
alg = Gibbs(; m=HMC(0.2, 3), s=PG(10))
alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10))
chn = sample(gdemo_default, alg, 1000)
end

Expand Down
Loading
Loading