Skip to content

Commit

Permalink
Variable naming, destructuring
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Jan 14, 2025
1 parent c44d81a commit aacb4ed
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,49 +429,49 @@ recursively on the remaining samplers, until no samplers remain. Return the glob
and a tuple of initial states for all component samplers.
"""
function gibbs_initialstep_recursive(
rng, model, varnames, samplers, vi, states=(); initial_params=nothing, kwargs...
rng, model, varname_tuples, samplers, vi, states=(); initial_params=nothing, kwargs...
)
# End recursion
if isempty(varnames) && isempty(samplers)
if isempty(varname_tuples) && isempty(samplers)
return vi, states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
varnames, varname_tuples_tail... = varname_tuples
sampler, samplers_tail... = samplers

# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames_local)[:]
DynamicPPL.subset(vi, varnames)[:]
end

# Construct the conditioned model.
model_local, context_local = make_conditional(model, varnames_local, vi)
conditioned_model, context = make_conditional(model, varnames, vi)

# Take initial step.
_, new_state_local = AbstractMCMC.step(
# Take initial step with the current sampler.
_, new_state = AbstractMCMC.step(
rng,
model_local,
sampler_local;
conditioned_model,
sampler;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context_local))
vi = merge(vi, get_global_varinfo(context))
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)

states = (states..., new_state_local)
states = (states..., new_state)
return gibbs_initialstep_recursive(
rng,
model,
varnames[2:end],
samplers[2:end],
varname_tuples_tail,
samplers_tail,
vi,
states;
initial_params=initial_params,
Expand Down Expand Up @@ -624,26 +624,26 @@ function on the tail, until there are no more samplers left.
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
varnames,
varname_tuples,
samplers,
states,
global_vi,
new_states=();
kwargs...,
)
# End recursion.
if isempty(varnames) && isempty(samplers) && isempty(states)
if isempty(varname_tuples) && isempty(samplers) && isempty(states)
return global_vi, new_states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
state_local = first(states)
varnames, varname_tuples_tail... = varname_tuples
sampler, samplers_tail... = samplers
state, states_tail... = states

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)
varinfo_local = match_linking!!(varinfo_local, state_local, model)
conditioned_model, context = make_conditional(model, varnames, global_vi)
vi = subset(global_vi, varnames)
vi = match_linking!!(vi, state, model)

# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
# sampled by other samplers, we don't need to `setparams`, but could rather simply
Expand All @@ -654,25 +654,25 @@ function gibbs_step_recursive(
# going to be a significant expense anyway.
# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = setparams_varinfo!!(
model_local, sampler_local, state_local, varinfo_local
state = setparams_varinfo!!(
conditioned_model, sampler, state, vi
)

# Take a step with the local sampler.
new_state_local = last(
AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
new_state = last(
AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...)
)

new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = merge(get_global_varinfo(context), new_vi_local)
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))

new_states = (new_states..., new_state_local)
new_states = (new_states..., new_state)
return gibbs_step_recursive(
rng,
model,
varnames[2:end],
varname_tuples[2:end],
samplers[2:end],
states[2:end],
new_global_vi,
Expand Down

0 comments on commit aacb4ed

Please sign in to comment.