From ee0e061e148bf8294348fd8112f7fec94342b5a0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 4 Dec 2024 14:46:59 +0000 Subject: [PATCH] Remove varinfo.jl, everything's merged back into DPPL --- test/dynamicppl/varinfo.jl | 205 ------------------------------------- test/runtests.jl | 2 - 2 files changed, 207 deletions(-) delete mode 100644 test/dynamicppl/varinfo.jl diff --git a/test/dynamicppl/varinfo.jl b/test/dynamicppl/varinfo.jl deleted file mode 100644 index ae2c70e53..000000000 --- a/test/dynamicppl/varinfo.jl +++ /dev/null @@ -1,205 +0,0 @@ -module DynamicPPLVarInfoTests - -using ..NumericalTests: check_numerical -using AbstractPPL: VarName -using BangBang: push!!, empty!! -using LinearAlgebra: I -using Test: @testset, @test -using Turing - -@testset "varinfo.jl" begin - # Declare empty model to make the Sampler constructor work. - @model empty_model() = begin - x = 1 - end - - function randr( - vi::DynamicPPL.VarInfo, - vn::VarName, - dist::Distribution, - spl::DynamicPPL.Sampler, - count::Bool=false, - ) - if !haskey(vi, vn) - r = rand(dist) - push!!(vi, vn, r, dist, spl) - r - elseif DynamicPPL.is_flagged(vi, vn, "del") - DynamicPPL.unset_flag!(vi, vn, "del") - r = rand(dist) - vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) - r - else - count && checkindex(vn, vi, spl) - DynamicPPL.updategid!(vi, vn, spl) - vi[vn] - end - end - - @testset "orders" begin - csym = gensym() # unique per model - vn_z1 = @varname z[1] - vn_z2 = @varname z[2] - vn_z3 = @varname z[3] - vn_z4 = @varname z[4] - vn_a1 = @varname a[1] - vn_a2 = @varname a[2] - vn_b = @varname b - - vi = DynamicPPL.VarInfo() - dists = [Categorical([0.7, 0.3]), Normal()] - - spl1 = DynamicPPL.Sampler(PG(5), empty_model()) - spl2 = DynamicPPL.Sampler(PG(5), empty_model()) - - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] - @test DynamicPPL.get_num_produce(vi) == 3 - - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) - @test DynamicPPL.is_flagged(vi, vn_z1, "del") - @test DynamicPPL.is_flagged(vi, vn_a1, "del") - @test DynamicPPL.is_flagged(vi, vn_z2, "del") - @test DynamicPPL.is_flagged(vi, vn_a2, "del") - @test DynamicPPL.is_flagged(vi, vn_z3, "del") - - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] - @test DynamicPPL.get_num_produce(vi) == 3 - - vi = empty!!(DynamicPPL.TypedVarInfo(vi)) - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 2] - @test vi.metadata.b.orders == [2] - @test DynamicPPL.get_num_produce(vi) == 3 - - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) - @test DynamicPPL.is_flagged(vi, vn_z1, "del") - @test DynamicPPL.is_flagged(vi, vn_a1, "del") - @test DynamicPPL.is_flagged(vi, vn_z2, "del") - @test DynamicPPL.is_flagged(vi, vn_a2, "del") - @test DynamicPPL.is_flagged(vi, vn_z3, "del") - - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 3] - @test vi.metadata.b.orders == [2] - @test DynamicPPL.get_num_produce(vi) == 3 - end - - @testset "varname" begin - @model function mat_name_test() - p = Array{Any}(undef, 2, 2) - for i in 1:2, j in 1:2 - p[i, j] ~ Normal(0, 1) - end - return p - end - chain = sample(mat_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1, 1]"], [0]; atol=0.25) - - @model function marr_name_test() - p = Array{Array{Any}}(undef, 2) - p[1] = Array{Any}(undef, 2) - p[2] = Array{Any}(undef, 2) - for i in 1:2, j in 1:2 - p[i][j] ~ Normal(0, 1) - end - return p - end - - chain = sample(marr_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1][1]"], [0]; atol=0.25) - end - - @testset "varinfo" begin - dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] - function test_varinfo!(vi) - spl2 = DynamicPPL.Sampler(PG(5, :w, :u), empty_model()) - vn_w = @varname w - randr(vi, vn_w, dists[1], spl2, true) - - vn_x = @varname x - vn_y = @varname y - vn_z = @varname z - vns = [vn_x, vn_y, vn_z] - - spl1 = DynamicPPL.Sampler(PG(5, :x, :y, :z), empty_model()) - for i in 1:3 - r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] - @test sum(val - r) <= 1e-9 - end - - idcs = DynamicPPL._getidcs(vi, spl1) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 - else - @test length(idcs) == 3 - end - @test length(vi[spl1]) == 7 - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 - else - @test length(idcs) == 1 - end - @test length(vi[spl2]) == 1 - - vn_u = @varname u - randr(vi, vn_u, dists[1], spl2, true) - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 - else - @test length(idcs) == 2 - end - @test length(vi[spl2]) == 2 - end - vi = DynamicPPL.VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) - end -end - -end diff --git a/test/runtests.jl b/test/runtests.jl index ca4d260a6..d3d0e45ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -87,8 +87,6 @@ end @testset "DynamicPPL integration" begin @timeit_include("dynamicppl/compiler.jl") - @timeit_include("dynamicppl/model.jl") - @timeit_include("dynamicppl/varinfo.jl") end @testset "utilities" begin