diff --git a/.github/workflows/IntegrationTests.yml b/.github/workflows/IntegrationTests.yml index 127dc40d..c823bba1 100644 --- a/.github/workflows/IntegrationTests.yml +++ b/.github/workflows/IntegrationTests.yml @@ -17,6 +17,7 @@ jobs: package: - DynamicHMC - AdvancedHMC + - Turing steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 79e6416e..bbc9aad7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "Pathfinder" uuid = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" authors = ["Seth Axen and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -24,6 +25,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] AbstractDifferentiation = "0.4" +Accessors = "0.1" Distributions = "0.25" Folds = "0.2" ForwardDiff = "0.10" diff --git a/src/Pathfinder.jl b/src/Pathfinder.jl index 20662c0f..495d4871 100644 --- a/src/Pathfinder.jl +++ b/src/Pathfinder.jl @@ -49,6 +49,9 @@ function __init__() Requires.@require AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" begin include("integration/advancedhmc.jl") end + Requires.@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin + include("integration/turing.jl") + end end end diff --git a/src/integration/turing.jl b/src/integration/turing.jl new file mode 100644 index 00000000..0cbc05c4 --- /dev/null +++ b/src/integration/turing.jl @@ -0,0 +1,161 @@ +using .Turing: Turing, DynamicPPL, MCMCChains +using Accessors: Accessors + +# utilities for working with Turing model parameter names using only the DynamicPPL API + +""" + flattened_varnames_list(model::DynamicPPL.Model) -> Vector{Symbol} + +Get a vector of varnames as `Symbol`s with one-to-one correspondence to the +flattened parameter vector. + +```julia +julia> @model function demo() + s ~ Dirac(1) + x = Matrix{Float64}(undef, 2, 4) + x[1, 1] ~ Dirac(2) + x[2, 1] ~ Dirac(3) + x[3] ~ Dirac(4) + y ~ Dirac(5) + x[4] ~ Dirac(6) + x[:, 3] ~ arraydist([Dirac(7), Dirac(8)]) + x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)]) + return s, x, y + end +demo (generic function with 2 methods) + +julia> flattened_varnames_list(demo()) +10-element Vector{Symbol}: + :s + Symbol("x[1,1]") + Symbol("x[2,1]") + Symbol("x[3]") + Symbol("x[4]") + Symbol("x[:,3][1]") + Symbol("x[:,3][2]") + Symbol("x[[2, 1],4][1]") + Symbol("x[[2, 1],4][2]") + :y +``` +""" +function flattened_varnames_list(model::DynamicPPL.Model) + varnames_ranges = varnames_to_ranges(model) + nsyms = maximum(maximum, values(varnames_ranges)) + syms = Vector{Symbol}(undef, nsyms) + for (var_name, range) in varnames_to_ranges(model) + sym = Symbol(var_name) + if length(range) == 1 + syms[range[begin]] = sym + continue + end + for i in eachindex(range) + syms[range[i]] = Symbol("$sym[$i]") + end + end + return syms +end + +# code snippet shared by @torfjelde +""" + varnames_to_ranges(model::DynamicPPL.Model) + varnames_to_ranges(model::DynamicPPL.VarInfo) + varnames_to_ranges(model::DynamicPPL.Metadata) + +Get `Dict` mapping variable names in model to their ranges in a corresponding parameter vector. + +# Examples + +```julia +julia> @model function demo() + s ~ Dirac(1) + x = Matrix{Float64}(undef, 2, 4) + x[1, 1] ~ Dirac(2) + x[2, 1] ~ Dirac(3) + x[3] ~ Dirac(4) + y ~ Dirac(5) + x[4] ~ Dirac(6) + x[:, 3] ~ arraydist([Dirac(7), Dirac(8)]) + x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)]) + return s, x, y + end +demo (generic function with 2 methods) + +julia> demo()() +(1, Any[2.0 4.0 7 10; 3.0 6.0 8 9], 5) + +julia> varnames_to_ranges(demo()) +Dict{AbstractPPL.VarName, UnitRange{Int64}} with 8 entries: + s => 1:1 + x[4] => 5:5 + x[:,3] => 6:7 + x[1,1] => 2:2 + x[2,1] => 3:3 + x[[2, 1],4] => 8:9 + x[3] => 4:4 + y => 10:10 +``` +""" +function varnames_to_ranges end + +varnames_to_ranges(model::DynamicPPL.Model) = varnames_to_ranges(DynamicPPL.VarInfo(model)) +function varnames_to_ranges(varinfo::DynamicPPL.UntypedVarInfo) + return varnames_to_ranges(varinfo.metadata) +end +function varnames_to_ranges(varinfo::DynamicPPL.TypedVarInfo) + offset = 0 + dicts = map(varinfo.metadata) do md + vns2ranges = varnames_to_ranges(md) + vals = collect(values(vns2ranges)) + vals_offset = map(r -> offset .+ r, vals) + offset += reduce((curr, r) -> max(curr, r[end]), vals; init=0) + Dict(zip(keys(vns2ranges), vals_offset)) + end + + return reduce(merge, dicts) +end +function varnames_to_ranges(metadata::DynamicPPL.Metadata) + idcs = map(Base.Fix1(getindex, metadata.idcs), metadata.vns) + ranges = metadata.ranges[idcs] + return Dict(zip(metadata.vns, ranges)) +end + +function pathfinder( + model::DynamicPPL.Model; + rng=Random.GLOBAL_RNG, + init_scale=2, + init_sampler=UniformSampler(init_scale), + init=nothing, + kwargs..., +) + var_names = flattened_varnames_list(model) + prob = Turing.optim_problem(model, Turing.MAP(); constrained=false, init_theta=init) + init_sampler(rng, prob.prob.u0) + result = pathfinder(prob.prob; rng, input=model, kwargs...) + draws = reduce(vcat, transpose.(prob.transform.(eachcol(result.draws)))) + chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result)) + result_new = Accessors.@set result.draws_transformed = chns + return result_new +end + +function multipathfinder( + model::DynamicPPL.Model, + ndraws::Int; + rng=Random.GLOBAL_RNG, + init_scale=2, + init_sampler=UniformSampler(init_scale), + nruns::Int, + kwargs..., +) + var_names = flattened_varnames_list(model) + fun = Turing.optim_function(model, Turing.MAP(); constrained=false) + init1 = fun.init() + init = [init_sampler(rng, init1)] + for _ in 2:nruns + push!(init, init_sampler(rng, deepcopy(init1))) + end + result = multipathfinder(fun.func, ndraws; rng, input=model, init, kwargs...) + draws = reduce(vcat, transpose.(fun.transform.(eachcol(result.draws)))) + chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result)) + result_new = Accessors.@set result.draws_transformed = chns + return result_new +end diff --git a/test/integration/Turing/Project.toml b/test/integration/Turing/Project.toml new file mode 100644 index 00000000..013572c9 --- /dev/null +++ b/test/integration/Turing/Project.toml @@ -0,0 +1,10 @@ +[deps] +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" + +[compat] +Pathfinder = "0.4" +Turing = "0.21" +julia = "1.6" diff --git a/test/integration/Turing/runtests.jl b/test/integration/Turing/runtests.jl new file mode 100644 index 00000000..de5ee966 --- /dev/null +++ b/test/integration/Turing/runtests.jl @@ -0,0 +1,45 @@ +using Pathfinder, Random, Test, Turing + +Random.seed!(0) + +@model function regression_model(x, y) + σ ~ truncated(Normal(); lower=0) + α ~ Normal() + β ~ filldist(Normal(), size(x, 2)) + y_hat = muladd(x, β, α) + y .~ Normal.(y_hat, σ) + return (; y) +end + +@testset "Turing integration" begin + x = 0:0.01:1 + y = sin.(x) .+ randn.() .* 0.2 .+ x + X = [x x .^ 2 x .^ 3] + model = regression_model(X, y) + expected_param_names = Symbol.(["α", "β[1]", "β[2]", "β[3]", "σ"]) + + result = pathfinder(model; ndraws=10_000) + @test result isa PathfinderResult + @test result.input === model + @test size(result.draws) == (5, 10_000) + @test result.draws_transformed isa MCMCChains.Chains + @test result.draws_transformed.info.pathfinder_result isa PathfinderResult + @test sort(names(result.draws_transformed)) == expected_param_names + @test all(>(0), result.draws_transformed[:σ]) + init_params = Vector(result.draws_transformed.value[1, :, 1]) + chns = sample(model, NUTS(), 10_000; init_params, progress=false) + @test mean(chns).nt.mean ≈ mean(result.draws_transformed).nt.mean rtol = 0.1 + + result = multipathfinder(model, 10_000; nruns=4) + @test result isa MultiPathfinderResult + @test result.input === model + @test size(result.draws) == (5, 10_000) + @test length(result.pathfinder_results) == 4 + @test result.draws_transformed isa MCMCChains.Chains + @test result.draws_transformed.info.pathfinder_result isa MultiPathfinderResult + @test sort(names(result.draws_transformed)) == expected_param_names + @test all(>(0), result.draws_transformed[:σ]) + init_params = Vector(result.draws_transformed.value[1, :, 1]) + chns = sample(model, NUTS(), 10_000; init_params, progress=false) + @test mean(chns).nt.mean ≈ mean(result.draws_transformed).nt.mean rtol = 0.1 +end