diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 462012676..51c7a1ed5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -195,7 +195,7 @@ function tilde_observe!!(context, right, left, vi) return left, acclogp_observe!!(context, vi, logp) end -function assume(rng, spl::Sampler, dist) +function assume(rng::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -291,14 +291,18 @@ end function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) return dot_assume(right, left, vns, vi) end -function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) +function dot_tilde_assume( + ::IsLeaf, rng::Random.AbstractRNG, ::AbstractContext, sampler, right, left, vns, vi +) return dot_assume(rng, sampler, right, vns, left, vi) end function dot_tilde_assume(::IsParent, context::AbstractContext, args...) return dot_tilde_assume(childcontext(context), args...) end -function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...) +function dot_tilde_assume( + ::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args... +) return dot_tilde_assume(rng, childcontext(context), args...) end @@ -371,7 +375,7 @@ function dot_assume( end function dot_assume( - rng, + rng::Random.AbstractRNG, spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, vns::AbstractVector{<:VarName}, @@ -404,7 +408,7 @@ function dot_assume( end function dot_assume( - rng, + rng::Random.AbstractRNG, spl::Union{SampleFromPrior,SampleFromUniform}, dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, @@ -416,7 +420,9 @@ function dot_assume( lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) return r, lp, vi end -function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) +function dot_assume( + rng::Random.AbstractRNG, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any +) return error( "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" ) @@ -436,7 +442,7 @@ function _maybe_invlink_broadcast(vi, vn, dist) end function get_and_set_val!( - rng, + rng::Random.AbstractRNG, vi::VarInfoOrThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, @@ -478,7 +484,7 @@ function get_and_set_val!( end function get_and_set_val!( - rng, + rng::Random.AbstractRNG, vi::VarInfoOrThreadSafeVarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}},