diff --git a/Project.toml b/Project.toml index 8e2cb92e2..98459eb0b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.28.5" +version = "0.28.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/varinfo.jl b/src/varinfo.jl index 3f7e335aa..6278d260f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -386,7 +386,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + gids = Set{Selector}[] orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. @@ -416,6 +416,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) dist_right = getdist(metadata_right, vn) # Give precedence to `metadata_right`. push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders`: giving precedence to `metadata_right` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -435,6 +437,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_left = getdist(metadata_left, vn) push!(dists, dist_left) + gid = metadata_left.gids[getidx(metadata_left, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_left, vn)) # `flags` @@ -453,6 +457,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_right = getdist(metadata_right, vn) push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -1598,25 +1604,40 @@ function BangBang.push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" end val = vectorize(dist, r) - - meta = getmetadata(vi, vn) - meta.idcs[vn] = length(meta.idcs) + 1 - push!(meta.vns, vn) - l = length(meta.vals) - n = length(val) - push!(meta.ranges, (l + 1):(l + n)) - append!(meta.vals, val) - push!(meta.dists, dist) - push!(meta.gids, gidset) - push!(meta.orders, get_num_produce(vi)) - push!(meta.flags["del"], false) - push!(meta.flags["trans"], false) + sym = getsym(vn) + if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + # The NamedTuple doesn't have an entry for this variable, let's add one. + md = Metadata( + Dict(vn => 1), + [vn], + [1:length(val)], + val, + [dist], + [gidset], + [get_num_produce(vi)], + Dict{String,BitVector}("trans" => [false], "del" => [false]), + ) + vi = Accessors.@set vi.metadata[sym] = md + else + meta = getmetadata(vi, vn) + meta.idcs[vn] = length(meta.idcs) + 1 + push!(meta.vns, vn) + l = length(meta.vals) + n = length(val) + push!(meta.ranges, (l + 1):(l + n)) + append!(meta.vals, val) + push!(meta.dists, dist) + push!(meta.gids, gidset) + push!(meta.orders, get_num_produce(vi)) + push!(meta.flags["del"], false) + push!(meta.flags["trans"], false) + end return vi end diff --git a/test/varinfo.jl b/test/varinfo.jl index ca87bf571..ff0e7235f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -145,6 +145,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end + + @testset "push!! to TypedVarInfo" begin + vn_x = @varname x + vn_y = @varname y + untyped_vi = VarInfo() + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + typed_vi = TypedVarInfo(untyped_vi) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + @test typed_vi[vn_x] == 1.0 + @test typed_vi[vn_y] == 2.0 + end + @testset "setgid!" begin vi = VarInfo() meta = vi.metadata @@ -645,6 +657,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal end + + # The below used to error, testing to avoid regression. + @testset "merge gids" begin + gidset_left = Set([Selector(1)]) + vi_left = VarInfo() + vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) + gidset_right = Set([Selector(2)]) + vi_right = VarInfo() + vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) + varinfo_merged = merge(vi_left, vi_right) + @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left + @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right + end end @testset "VarInfo with selectors" begin