Skip to content

Commit

Permalink
Add Turing support (#44)
Browse files Browse the repository at this point in the history
* Add WIP turing converters

* Fix implementations

* Conditionally load turing integration

* Store chains in result

* Increment patch number

* Add Turing integration test

* Pass model to functions

* Test more cases

* Test means are approximately equal

* Increase tolerance
  • Loading branch information
sethaxen authored May 2, 2022
1 parent 4e5717d commit 9ec7e6b
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/IntegrationTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ jobs:
package:
- DynamicHMC
- AdvancedHMC
- Turing
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "Pathfinder"
uuid = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -24,6 +25,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractDifferentiation = "0.4"
Accessors = "0.1"
Distributions = "0.25"
Folds = "0.2"
ForwardDiff = "0.10"
Expand Down
3 changes: 3 additions & 0 deletions src/Pathfinder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ function __init__()
Requires.@require AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" begin
include("integration/advancedhmc.jl")
end
Requires.@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin
include("integration/turing.jl")
end
end

end
161 changes: 161 additions & 0 deletions src/integration/turing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
using .Turing: Turing, DynamicPPL, MCMCChains
using Accessors: Accessors

# utilities for working with Turing model parameter names using only the DynamicPPL API

"""
flattened_varnames_list(model::DynamicPPL.Model) -> Vector{Symbol}
Get a vector of varnames as `Symbol`s with one-to-one correspondence to the
flattened parameter vector.
```julia
julia> @model function demo()
s ~ Dirac(1)
x = Matrix{Float64}(undef, 2, 4)
x[1, 1] ~ Dirac(2)
x[2, 1] ~ Dirac(3)
x[3] ~ Dirac(4)
y ~ Dirac(5)
x[4] ~ Dirac(6)
x[:, 3] ~ arraydist([Dirac(7), Dirac(8)])
x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)])
return s, x, y
end
demo (generic function with 2 methods)
julia> flattened_varnames_list(demo())
10-element Vector{Symbol}:
:s
Symbol("x[1,1]")
Symbol("x[2,1]")
Symbol("x[3]")
Symbol("x[4]")
Symbol("x[:,3][1]")
Symbol("x[:,3][2]")
Symbol("x[[2, 1],4][1]")
Symbol("x[[2, 1],4][2]")
:y
```
"""
function flattened_varnames_list(model::DynamicPPL.Model)
    # Map each variable name to its range in the flattened parameter vector.
    varnames_ranges = varnames_to_ranges(model)
    # The total number of flattened parameters is the largest index any range covers.
    nsyms = maximum(maximum, values(varnames_ranges))
    syms = Vector{Symbol}(undef, nsyms)
    # Reuse the dict computed above; the original called `varnames_to_ranges(model)`
    # a second time here, rebuilding the `VarInfo` (and re-evaluating the model).
    for (var_name, range) in varnames_ranges
        sym = Symbol(var_name)
        if length(range) == 1
            # Scalar variable: a single symbol with no index suffix.
            syms[range[begin]] = sym
            continue
        end
        # Multivariate variable: append a linear index to each flattened entry.
        for i in eachindex(range)
            syms[range[i]] = Symbol("$sym[$i]")
        end
    end
    return syms
end

# code snippet shared by @torfjelde
"""
varnames_to_ranges(model::DynamicPPL.Model)
varnames_to_ranges(model::DynamicPPL.VarInfo)
varnames_to_ranges(model::DynamicPPL.Metadata)
Get `Dict` mapping variable names in model to their ranges in a corresponding parameter vector.
# Examples
```julia
julia> @model function demo()
s ~ Dirac(1)
x = Matrix{Float64}(undef, 2, 4)
x[1, 1] ~ Dirac(2)
x[2, 1] ~ Dirac(3)
x[3] ~ Dirac(4)
y ~ Dirac(5)
x[4] ~ Dirac(6)
x[:, 3] ~ arraydist([Dirac(7), Dirac(8)])
x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)])
return s, x, y
end
demo (generic function with 2 methods)
julia> demo()()
(1, Any[2.0 4.0 7 10; 3.0 6.0 8 9], 5)
julia> varnames_to_ranges(demo())
Dict{AbstractPPL.VarName, UnitRange{Int64}} with 8 entries:
s => 1:1
x[4] => 5:5
x[:,3] => 6:7
x[1,1] => 2:2
x[2,1] => 3:3
x[[2, 1],4] => 8:9
x[3] => 4:4
y => 10:10
```
"""
function varnames_to_ranges end

# For a model, construct a `VarInfo` and delegate to the `VarInfo` methods.
function varnames_to_ranges(model::DynamicPPL.Model)
    return varnames_to_ranges(DynamicPPL.VarInfo(model))
end
# An untyped `VarInfo` keeps all variables in a single metadata field.
varnames_to_ranges(varinfo::DynamicPPL.UntypedVarInfo) = varnames_to_ranges(varinfo.metadata)
function varnames_to_ranges(varinfo::DynamicPPL.TypedVarInfo)
    # Each field of `varinfo.metadata` stores ranges local to that field; shift
    # each field's ranges so they index into the single concatenated vector.
    offset = 0
    per_field = map(varinfo.metadata) do md
        local_ranges = varnames_to_ranges(md)
        shifted = Dict(vn => offset .+ r for (vn, r) in pairs(local_ranges))
        # Advance the offset by the length of this field's parameter block
        # (the largest endpoint among its unshifted ranges).
        offset += maximum(last, values(local_ranges); init=0)
        shifted
    end
    return reduce(merge, per_field)
end
function varnames_to_ranges(metadata::DynamicPPL.Metadata)
    # For each varname, look up its index in `idcs`, then the corresponding
    # range within this metadata's parameter block.
    return Dict(vn => metadata.ranges[metadata.idcs[vn]] for vn in metadata.vns)
end

function pathfinder(
    model::DynamicPPL.Model;
    rng=Random.GLOBAL_RNG,
    init_scale=2,
    init_sampler=UniformSampler(init_scale),
    init=nothing,
    kwargs...,
)
    # Parameter names aligned with the flattened (unconstrained) parameter vector.
    param_names = flattened_varnames_list(model)
    # Build the unconstrained optimization problem for the model's MAP objective.
    optim_prob = Turing.optim_problem(model, Turing.MAP(); constrained=false, init_theta=init)
    # Sample the starting point in place into the problem's initial vector.
    init_sampler(rng, optim_prob.prob.u0)
    result = pathfinder(optim_prob.prob; rng, input=model, kwargs...)
    # Transform each draw back to constrained space; one row per draw.
    constrained_rows = [transpose(optim_prob.transform(draw)) for draw in eachcol(result.draws)]
    draw_matrix = reduce(vcat, constrained_rows)
    chns = MCMCChains.Chains(draw_matrix, param_names; info=(; pathfinder_result=result))
    # Attach the chains to an updated copy of the (immutable) result.
    return Accessors.@set result.draws_transformed = chns
end

function multipathfinder(
    model::DynamicPPL.Model,
    ndraws::Int;
    rng=Random.GLOBAL_RNG,
    init_scale=2,
    init_sampler=UniformSampler(init_scale),
    nruns::Int,
    kwargs...,
)
    # Parameter names aligned with the flattened (unconstrained) parameter vector.
    param_names = flattened_varnames_list(model)
    # Unconstrained MAP objective, initializer, and transform for the model.
    optim_fun = Turing.optim_function(model, Turing.MAP(); constrained=false)
    base_init = optim_fun.init()
    # One independently sampled starting point per run; the first run reuses
    # `base_init` directly, later runs fill copies so runs don't share storage.
    init = map(1:nruns) do i
        point = i == 1 ? base_init : deepcopy(base_init)
        init_sampler(rng, point)
    end
    result = multipathfinder(optim_fun.func, ndraws; rng, input=model, init, kwargs...)
    # Transform each draw back to constrained space; one row per draw.
    constrained_rows = [transpose(optim_fun.transform(draw)) for draw in eachcol(result.draws)]
    draw_matrix = reduce(vcat, constrained_rows)
    chns = MCMCChains.Chains(draw_matrix, param_names; info=(; pathfinder_result=result))
    # Attach the chains to an updated copy of the (immutable) result.
    return Accessors.@set result.draws_transformed = chns
end
10 changes: 10 additions & 0 deletions test/integration/Turing/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
Pathfinder = "0.4"
Turing = "0.21"
julia = "1.6"
45 changes: 45 additions & 0 deletions test/integration/Turing/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using Pathfinder, Random, Test, Turing

Random.seed!(0)

# Bayesian linear regression: y ~ Normal(x*β + α, σ) with standard-normal priors.
@model function regression_model(x, y)
    # Noise scale, constrained to be positive.
    σ ~ truncated(Normal(); lower=0)
    # Intercept and one coefficient per column of `x`.
    α ~ Normal()
    β ~ filldist(Normal(), size(x, 2))
    μ = muladd(x, β, α)
    y .~ Normal.(μ, σ)
    return (; y)
end

@testset "Turing integration" begin
    x = 0:0.01:1
    y = sin.(x) .+ randn.() .* 0.2 .+ x
    X = [x x .^ 2 x .^ 3]
    model = regression_model(X, y)
    expected_param_names = Symbol.(["α", "β[1]", "β[2]", "β[3]", "σ"])

    result = pathfinder(model; ndraws=10_000)
    @test result isa PathfinderResult
    @test result.input === model
    @test size(result.draws) == (5, 10_000)
    @test result.draws_transformed isa MCMCChains.Chains
    @test result.draws_transformed.info.pathfinder_result isa PathfinderResult
    @test sort(names(result.draws_transformed)) == expected_param_names
    # σ is constrained positive, so all transformed draws of σ must be > 0.
    # (Original had `result.draws_transformed[]`, an invalid empty index.)
    @test all(>(0), result.draws_transformed[:σ])
    init_params = Vector(result.draws_transformed.value[1, :, 1])
    chns = sample(model, NUTS(), 10_000; init_params, progress=false)
    # Posterior means from NUTS and Pathfinder should roughly agree.
    # (Original was missing the `≈` operator — a syntax error.)
    @test mean(chns).nt.mean ≈ mean(result.draws_transformed).nt.mean rtol = 0.1

    result = multipathfinder(model, 10_000; nruns=4)
    @test result isa MultiPathfinderResult
    @test result.input === model
    @test size(result.draws) == (5, 10_000)
    @test length(result.pathfinder_results) == 4
    @test result.draws_transformed isa MCMCChains.Chains
    @test result.draws_transformed.info.pathfinder_result isa MultiPathfinderResult
    @test sort(names(result.draws_transformed)) == expected_param_names
    # σ is constrained positive, so all transformed draws of σ must be > 0.
    @test all(>(0), result.draws_transformed[:σ])
    init_params = Vector(result.draws_transformed.value[1, :, 1])
    chns = sample(model, NUTS(), 10_000; init_params, progress=false)
    # Posterior means from NUTS and multi-path Pathfinder should roughly agree.
    @test mean(chns).nt.mean ≈ mean(result.draws_transformed).nt.mean rtol = 0.1
end

2 comments on commit 9ec7e6b

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/59521

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.1 -m "<description of version>" 9ec7e6b1c0529379b837d39a2d5d4ca6c99847e2
git push origin v0.4.1

Please sign in to comment.