From dfa63384d4e0d57cfee735358320ded829d9fc20 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 14 Oct 2024 17:13:50 +0100 Subject: [PATCH] Allowing pushing new symbols to TypedVarInfo --- src/varinfo.jl | 26 +++++++++++++++++++++----- test/varinfo.jl | 12 ++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index dd4e3cab2..6e3cd5714 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1624,13 +1624,30 @@ function BangBang.push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" end - meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, gidset, get_num_produce(vi)) + sym = getsym(vn) + if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + # The NamedTuple doesn't have an entry for this variable, let's add one. + val = tovec(r) + md = Metadata( + Dict(vn => 1), + [vn], + [1:length(val)], + val, + [dist], + [gidset], + [get_num_produce(vi)], + Dict{String,BitVector}("trans" => [false], "del" => [false]), + ) + vi = Accessors.@set vi.metadata[sym] = md + else + meta = getmetadata(vi, vn) + push!(meta, vn, r, dist, gidset, get_num_produce(vi)) + end return vi end @@ -1648,7 +1665,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) - return meta end diff --git a/test/varinfo.jl b/test/varinfo.jl index e0692f275..e56fa55e4 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -145,6 +145,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end + + @testset "push!! to TypedVarInfo" begin + vn_x = @varname x + vn_y = @varname y + untyped_vi = VarInfo() + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + typed_vi = TypedVarInfo(untyped_vi) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + @test typed_vi[vn_x] == 1.0 + @test typed_vi[vn_y] == 2.0 + end + @testset "setgid!" begin vi = VarInfo() meta = vi.metadata