From 5e701f6e44a7020fead729f0ad80a9e4778bf1e8 Mon Sep 17 00:00:00 2001 From: Saumil Shah Date: Mon, 26 Jun 2023 11:20:39 +0200 Subject: [PATCH 1/3] extend turing_inference args --- Project.toml | 2 ++ src/turing_inference.jl | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index a564f15c..f36843f4 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,9 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" diff --git a/src/turing_inference.jl b/src/turing_inference.jl index a36b13da..0808c2d6 100644 --- a/src/turing_inference.jl +++ b/src/turing_inference.jl @@ -9,14 +9,17 @@ function turing_inference(prob::DiffEqBase.DEProblem, likelihood = (u, p, t, σ) -> MvNormal(u, Diagonal((σ[1])^2 * ones(length(u)))), - num_samples = 1000, sampler = Turing.NUTS(0.65), + num_samples = 1000, + sampler = Turing.NUTS(0.65), + parallel_type = MCMCSerial(), + n_chains = 1, syms = [Turing.@varname(theta[i]) for i in 1:length(priors)], sample_u0 = false, save_idxs = nothing, progress = false, kwargs...) N = length(priors) - Turing.@model function mf(x, ::Type{T} = Float64) where {T <: Real} + Turing.@model function infer(x, ::Type{T} = Float64) where {T <: Real} theta = Vector{T}(undef, length(priors)) for i in 1:length(priors) theta[i] ~ NamedDist(priors[i], syms[i]) @@ -54,7 +57,14 @@ function turing_inference(prob::DiffEqBase.DEProblem, end false # Instantiate a Model object. - model = mf(data) - chn = sample(model, sampler, num_samples; progress = progress) + model = infer(data) + chn = sample( + model, + sampler, + parallel_type, + num_samples, + n_chains; + progress = progress + ) return chn end From 328abcd23717db603754965e461fd21178b572c3 Mon Sep 17 00:00:00 2001 From: Saumil Shah Date: Mon, 26 Jun 2023 12:01:28 +0200 Subject: [PATCH 2/3] added a test and updated docs --- docs/src/methods.md | 5 ++--- test/turing.jl | 21 +++++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/docs/src/methods.md b/docs/src/methods.md index 7ec581c8..1eb303cb 100644 --- a/docs/src/methods.md +++ b/docs/src/methods.md @@ -38,7 +38,7 @@ parameter list. ```julia function turing_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors; likelihood_dist_priors, likelihood, num_samples = 1000, - sampler = Turing.NUTS(num_samples, 0.65), syms, kwargs...) + sampler = Turing.NUTS(num_samples, 0.65), parallel_type = MCMCSerial(), n_chains = 1, syms, kwargs...) end ``` @@ -49,8 +49,7 @@ observations for the differential equation system at time point `t[i]` (or highe dimensional). `priors` is an array of prior distributions for each parameter, specified via a [Distributions.jl](https://juliastats.github.io/Distributions.jl/dev/) -type. `num_samples` is the number of samples per MCMC chain. The extra `kwargs` are given to the internal differential -equation solver. +type. `num_samples` is the number of samples per MCMC chain. Sampling from multiple chains is possible, see [`Turing.jl` documentation](https://turinglang.org/v0.26/docs/using-turing/guide#sampling-multiple-chains), serially or parallelly using `parallel_type` and `n_chains`. The extra `kwargs` are given to the internal differential equation solver. ### dynamichmc_inference diff --git a/test/turing.jl b/test/turing.jl index 4bf4cd37..d7701ebc 100644 --- a/test/turing.jl +++ b/test/turing.jl @@ -1,5 +1,7 @@ using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions, RecursiveArrayTools using Test, Distributions, SteadyStateDiffEq +using Turing + println("One parameter case") f1 = @ode_def begin dx = a * x - x * y @@ -14,25 +16,28 @@ randomized = VectorOfArray([(sol(t[i]) + 0.01randn(2)) for i in 1:length(t)]) data = convert(Array, randomized) priors = [Normal(1.5, 0.01)] -bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, - syms = [:a]) +bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, syms = [:a]) @show bayesian_result @test mean(get(bayesian_result, :a)[1])≈1.5 atol=3e-1 -bayesian_result = turing_inference(prob1, Rosenbrock23(autodiff = false), t, data, priors; - num_samples = 500, - syms = [:a]) +bayesian_result = turing_inference(prob1, Rosenbrock23(autodiff = false), t, data, priors; num_samples = 500, syms = [:a]) bayesian_result = turing_inference(prob1, Rosenbrock23(), t, data, priors; num_samples = 500, syms = [:a]) +# --- test Multithreaded sampling +println("Multithreaded case") +result_threaded = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, syms = [:a], parallel_type=MCMCThreads(), n_chains=2) + +@test length(result_threaded.value.axes[3]) == 2 +@test mean(get(result_threaded, :a)[1])≈1.5 atol=3e-1 +# --- + priors = [Normal(1.0, 0.01), Normal(1.0, 0.01), Normal(1.5, 0.01)] -bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, - sample_u0 = true, - syms = [:u1, :u2, :a]) +bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, sample_u0 = true, syms = [:u1, :u2, :a]) @test mean(get(bayesian_result, :a)[1])≈1.5 atol=3e-1 @test mean(get(bayesian_result, :u1)[1])≈1.0 atol=3e-1 From 5004545f68e6fd130a768a19cfbbddde8056f584 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Mon, 26 Jun 2023 16:23:30 +0530 Subject: [PATCH 3/3] Apply suggestions from code review --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index f36843f4..a564f15c 100644 --- a/Project.toml +++ b/Project.toml @@ -19,9 +19,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" -ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"