Skip to content

Commit

Permalink
Update the currently buggy and incorrect tilde overloads in mh.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 4, 2024
1 parent 452d0d0 commit 679989a
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions src/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,42 +442,45 @@ end
####
#### Compiler interface, i.e. tilde operators.
####
function DynamicPPL.assume(rng, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi)
function DynamicPPL.assume(
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
)
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
DynamicPPL.updategid!(vi, vn, spl)
r = vi[vn]
return r, logpdf_with_trans(dist, r, istrans(vi, vn)), vi
# Return.
return retval
end

function DynamicPPL.dot_assume(
rng,
spl::Sampler{<:MH},
dist::MultivariateDistribution,
vn::VarName,
vns::AbstractVector{<:VarName},
var::AbstractMatrix,
vi,
vi::AbstractVarInfo,
)
@assert dim(dist) == size(var, 1)
getvn = i -> VarName(vn, vn.indexing * "[:,$i]")
vns = getvn.(1:size(var, 2))
DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl))
r = vi[vns]
var .= r
return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))), vi
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dist, vns[1], var, vi)

Check warning on line 465 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L465

Added line #L465 was not covered by tests
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
DynamicPPL.updategid!.((vi,), vns, (spl,))

Check warning on line 467 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L467

Added line #L467 was not covered by tests
# Return.
return retval

Check warning on line 469 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L469

Added line #L469 was not covered by tests
end
function DynamicPPL.dot_assume(
rng,
spl::Sampler{<:MH},
dists::Union{Distribution,AbstractArray{<:Distribution}},
vn::VarName,
vns::AbstractArray{<:VarName},
var::AbstractArray,
vi,
vi::AbstractVarInfo,
)
getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]")
vns = getvn.(CartesianIndices(var))
DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl))
r = reshape(vi[vec(vns)], size(var))
var .= r
return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))), vi
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dists, vns, var, vi)

Check warning on line 480 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L480

Added line #L480 was not covered by tests
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
DynamicPPL.updategid!.((vi,), vns, (spl,))
return retval

Check warning on line 483 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L482-L483

Added lines #L482 - L483 were not covered by tests
end

function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
Expand Down

0 comments on commit 679989a

Please sign in to comment.