Skip to content

Commit

Permalink
For VarInfo, fix merge and allow push!!ing new Symbols (#690)
Browse files Browse the repository at this point in the history
* Fix treatment of gid in merge(::Metadata)

* Allowing pushing new symbols to TypedVarInfo

* Bump patch version to 0.30.1
  • Loading branch information
mhauru committed Oct 17, 2024
1 parent 30c10c2 commit 4650230
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.28.5"
version = "0.28.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
53 changes: 37 additions & 16 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
ranges = Vector{UnitRange{Int}}()
vals = T[]
dists = D[]
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
gids = Set{Selector}[]
orders = Int[]
flags = Dict{String,BitVector}()
# Initialize the `flags`.
Expand Down Expand Up @@ -416,6 +416,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
dist_right = getdist(metadata_right, vn)
# Give precedence to `metadata_right`.
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`: giving precedence to `metadata_right`
push!(orders, getorder(metadata_right, vn))
# `flags`
Expand All @@ -435,6 +437,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
# `dists`
dist_left = getdist(metadata_left, vn)
push!(dists, dist_left)
gid = metadata_left.gids[getidx(metadata_left, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_left, vn))
# `flags`
Expand All @@ -453,6 +457,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
# `dists`
dist_right = getdist(metadata_right, vn)
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_right, vn))
# `flags`
Expand Down Expand Up @@ -1598,25 +1604,40 @@ function BangBang.push!!(
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
)
if vi isa UntypedVarInfo
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
elseif vi isa TypedVarInfo
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
end

val = vectorize(dist, r)

meta = getmetadata(vi, vn)
meta.idcs[vn] = length(meta.idcs) + 1
push!(meta.vns, vn)
l = length(meta.vals)
n = length(val)
push!(meta.ranges, (l + 1):(l + n))
append!(meta.vals, val)
push!(meta.dists, dist)
push!(meta.gids, gidset)
push!(meta.orders, get_num_produce(vi))
push!(meta.flags["del"], false)
push!(meta.flags["trans"], false)
sym = getsym(vn)
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
# The NamedTuple doesn't have an entry for this variable, let's add one.
md = Metadata(
Dict(vn => 1),
[vn],
[1:length(val)],
val,
[dist],
[gidset],
[get_num_produce(vi)],
Dict{String,BitVector}("trans" => [false], "del" => [false]),
)
vi = Accessors.@set vi.metadata[sym] = md
else
meta = getmetadata(vi, vn)
meta.idcs[vn] = length(meta.idcs) + 1
push!(meta.vns, vn)
l = length(meta.vals)
n = length(val)
push!(meta.ranges, (l + 1):(l + n))
append!(meta.vals, val)
push!(meta.dists, dist)
push!(meta.gids, gidset)
push!(meta.orders, get_num_produce(vi))
push!(meta.flags["del"], false)
push!(meta.flags["trans"], false)
end

return vi
end
Expand Down
25 changes: 25 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
test_varinfo!(vi)
test_varinfo!(empty!!(TypedVarInfo(vi)))
end

@testset "push!! to TypedVarInfo" begin
vn_x = @varname x
vn_y = @varname y
untyped_vi = VarInfo()
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
typed_vi = TypedVarInfo(untyped_vi)
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
@test typed_vi[vn_x] == 1.0
@test typed_vi[vn_y] == 2.0
end

@testset "setgid!" begin
vi = VarInfo()
meta = vi.metadata
Expand Down Expand Up @@ -645,6 +657,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
@test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal
end

# The below used to error, testing to avoid regression.
@testset "merge gids" begin
gidset_left = Set([Selector(1)])
vi_left = VarInfo()
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
gidset_right = Set([Selector(2)])
vi_right = VarInfo()
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
varinfo_merged = merge(vi_left, vi_right)
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
end
end

@testset "VarInfo with selectors" begin
Expand Down

0 comments on commit 4650230

Please sign in to comment.