Skip to content

Commit

Permalink
results is wrong
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Aug 15, 2024
1 parent 55dbab5 commit 67ff8e8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 69 deletions.
26 changes: 15 additions & 11 deletions gibbs_example/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ using OrderedCollections

##

# TODO: introduce some kind of parameter format, for instance, a flattened vector
# then define some kind of function to transform the flattened vector into model's representation

struct Gibbs <: AbstractMCMC.AbstractSampler
sampler_map::OrderedDict
end
Expand Down Expand Up @@ -73,12 +70,12 @@ function AbstractMCMC.step(
cond_val = NamedTuple{Tuple(group_complement)}(
Tuple([vi[g] for g in group_complement])
)
cond_logdensity = condition(logdensity_model.logdensity, cond_val)
sub_state = recompute_logprob!!(cond_logdensity, getparams(sub_state), sub_state)
sub_state = last(
AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(
condition(logdensity_model.logdensity, cond_val)
),
AbstractMCMC.LogDensityModel(cond_logdensity),
sub_spl,
sub_state,
args...;
Expand All @@ -87,8 +84,8 @@ function AbstractMCMC.step(
)
state.states[group] = sub_state
end
for sub_state in values(state.states)
vi = merge(vi, getparams(sub_state))
for (group, sub_state) in state.states
vi = merge(vi, unflatten(getparams(sub_state), group))
end
return GibbsTransition(vi), GibbsState(vi, state.states)
end
Expand All @@ -103,9 +100,16 @@ samples = sample(
OrderedDict(
(:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
(:w,) => PriorMH(Dirichlet(2, 1.0)),
(, :w) => RWMH(1),
(,) => RWMH(1),
),
),
10000;
100000;
initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]),
)
);

z_samples = [sample.values.z for sample in samples][20001:end]
μ_samples = [sample.values.μ for sample in samples][20001:end]
w_samples = [sample.values.w for sample in samples][20001:end]

mean(μ_samples)
mean(w_samples)
71 changes: 13 additions & 58 deletions gibbs_example/gmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,42 +44,16 @@ function condition(gmm::GMM, conditioned_values::NamedTuple)
return ConditionedGMM(gmm.data, conditioned_values)
end

function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params)
return log_joint(;
μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x
)
end

function _logdensity(gmm::ConditionedGMM{(:z,)}, params)
return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x)
end

function LogDensityProblems.logdensity(
gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}},
params_vec::AbstractVector,
)
@assert length(params_vec) == 60
return _logdensity(gmm, (; z=params_vec))
end
function LogDensityProblems.logdensity(
gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector
)
@assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))"
return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4]))
end

function LogDensityProblems.dimension(gmm::GMM)
return 4 + size(gmm.data.x, 1)
end

function LogDensityProblems.dimension(
gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}
)
return 4
end

function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)})
return size(gmm.data.x, 1)
function LogDensityProblems.logdensity(gmm::ConditionedGMM{names}, params::AbstractVector) where {names}
if Set(names) == Set([, :w]) # conditioned on μ, w, so params are z
return log_joint(; μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params, x=gmm.data.x)
elseif Set(names) == Set([:z, :w]) # conditioned on z, w, so params are μ
return log_joint(; μ=params, w=gmm.conditioned_values.w, z=gmm.conditioned_values.z, x=gmm.data.x)
elseif Set(names) == Set([:z, ]) # conditioned on z, μ, so params are w
return log_joint(; μ=gmm.conditioned_values.μ, w=params, z=gmm.conditioned_values.z, x=gmm.data.x)
else
error("Unsupported conditioning configuration.")
end
end

function LogDensityProblems.capabilities(::GMM)
Expand All @@ -91,41 +65,22 @@ function LogDensityProblems.capabilities(::ConditionedGMM)
end

function flatten(nt::NamedTuple)
if Set(keys(nt)) == Set([, :w])
return vcat(nt.μ, nt.w)
elseif Set(keys(nt)) == Set([:z])
return nt.z
else
error()
end
return only(values(nt))
end

function unflatten(vec::AbstractVector, group::Tuple)
if Set(group) == Set([, :w])
return (; μ=vec[1:2], w=vec[3:4])
elseif Set(group) == Set([:z])
return (; z=vec)
else
error()
end
return NamedTuple((only(group) => vec,))
end

# sampler's states to internal representation
# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?)

# the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed
function recompute_logprob!!(gmm::ConditionedGMM, vals, state)
return setlogp!(state, _logdensity(gmm, vals))
return setlogp!!(state, LogDensityProblems.logdensity(gmm, vals))
end

## test using Turing

# data generation

using Distributions
using FillArrays
using LinearAlgebra
using Random

w = [0.5, 0.5]
μ = [-3.5, 0.5]
Expand Down

0 comments on commit 67ff8e8

Please sign in to comment.