From df414204484368a4eb701a0fa91ddb999a971012 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 14 Oct 2024 13:40:41 +0100 Subject: [PATCH] Fix treatment of gid in merge(::Metadata) --- src/varinfo.jl | 8 +++++++- test/varinfo.jl | 13 +++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 2670397d9..dd4e3cab2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -392,7 +392,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + gids = Set{Selector}[] orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. @@ -422,6 +422,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) dist_right = getdist(metadata_right, vn) # Give precedence to `metadata_right`. push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders`: giving precedence to `metadata_right` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -441,6 +443,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_left = getdist(metadata_left, vn) push!(dists, dist_left) + gid = metadata_left.gids[getidx(metadata_left, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_left, vn)) # `flags` @@ -459,6 +463,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_right = getdist(metadata_right, vn) push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_right, vn)) # `flags` diff --git a/test/varinfo.jl b/test/varinfo.jl index 6a3d8d2bc..e0692f275 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -645,6 +645,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal end + + # The below used to error, testing to avoid regression. + @testset "merge gids" begin + gidset_left = Set([Selector(1)]) + vi_left = VarInfo() + vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) + gidset_right = Set([Selector(2)]) + vi_right = VarInfo() + vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) + varinfo_merged = merge(vi_left, vi_right) + @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left + @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right + end end @testset "VarInfo with selectors" begin