From 679989af47ed861d67639198df7f38ba50fe262f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 13:22:13 +0100 Subject: [PATCH] Update the currently buggy and incorrect tilde overloads in `mh.jl` --- src/mcmc/mh.jl | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index cfbbc70eb..67c7e3612 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -442,42 +442,45 @@ end #### #### Compiler interface, i.e. tilde operators. #### -function DynamicPPL.assume(rng, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi) +function DynamicPPL.assume( + rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi +) + # Just defer to `SampleFromPrior`. + retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) + # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call. DynamicPPL.updategid!(vi, vn, spl) - r = vi[vn] - return r, logpdf_with_trans(dist, r, istrans(vi, vn)), vi + # Return. + return retval end function DynamicPPL.dot_assume( rng, spl::Sampler{<:MH}, dist::MultivariateDistribution, - vn::VarName, + vns::AbstractVector{<:VarName}, var::AbstractMatrix, - vi, + vi::AbstractVarInfo, ) - @assert dim(dist) == size(var, 1) - getvn = i -> VarName(vn, vn.indexing * "[:,$i]") - vns = getvn.(1:size(var, 2)) - DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) - r = vi[vns] - var .= r - return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))), vi + # Just defer to `SampleFromPrior`. + retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dist, vns[1], var, vi) + # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call. + DynamicPPL.updategid!.((vi,), vns, (spl,)) + # Return. + return retval end function DynamicPPL.dot_assume( rng, spl::Sampler{<:MH}, dists::Union{Distribution,AbstractArray{<:Distribution}}, - vn::VarName, + vns::AbstractArray{<:VarName}, var::AbstractArray, - vi, + vi::AbstractVarInfo, ) - getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]") - vns = getvn.(CartesianIndices(var)) - DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) - r = reshape(vi[vec(vns)], size(var)) - var .= r - return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))), vi + # Just defer to `SampleFromPrior`. + retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dists, vns, var, vi) + # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call. + DynamicPPL.updategid!.((vi,), vns, (spl,)) + return retval end function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)