dynamic_inference update #60

Merged · 18 commits · Dec 1, 2018
Changes from 13 commits
10 changes: 5 additions & 5 deletions REQUIRE
@@ -1,5 +1,5 @@
-julia 0.7
-DiffEqBase 1.7.0
+julia 1.0
+DiffEqBase 4.28.1
 #Mamba
 Stan
 Distributions
@@ -10,8 +10,8 @@ RecursiveArrayTools
 ParameterizedFunctions
 OrdinaryDiffEq
 Parameters
-#DiffWrappers
-ContinuousTransformations 1.0.0
-DynamicHMC 0.1.1
+DynamicHMC
 Distances
 ApproxBayes
+TransformVariables
+LogDensityProblems
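A note on the REQUIRE format (pre-Project.toml): `julia 1.0` sets the minimum supported Julia version, a version after a package name (e.g. `DiffEqBase 4.28.1`) is a lower bound, and a bare name such as the now-unpinned `DynamicHMC` accepts any registered release.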
2 changes: 1 addition & 1 deletion src/DiffEqBayes.jl
@@ -1,7 +1,7 @@
 module DiffEqBayes
 using DiffEqBase, Stan, Distributions, Turing, MacroTools
 using OrdinaryDiffEq, ParameterizedFunctions, RecursiveArrayTools
-using DynamicHMC, ContinuousTransformations
+using DynamicHMC, TransformVariables, LogDensityProblems
 using Parameters, Distributions, Optim
 using Distances, ApproxBayes
53 changes: 11 additions & 42 deletions src/dynamichmc_inference.jl
@@ -8,19 +8,15 @@ end

 function (P::DynamicHMCPosterior)(a)
     @unpack alg, problem, likelihood, priors, kwargs = P

-    prob = problem_new_parameters(problem, a)
+    prob = remake(problem,u0 = convert.(eltype(a.a),problem.u0),p=a.a)
     sol = solve(prob, alg; kwargs...)
     if any((s.retcode != :Success for s in sol))
         ℓ = -Inf
     else
         ℓ = likelihood(sol)
     end
-
-    if !isfinite(ℓ) && (ℓ ≠ -Inf)
-        ℓ = -Inf # protect against NaN etc, is it needed?
-    end
-    logpdf_sum = 0
+    logpdf_sum = 0.0
     for i in length(a)
         logpdf_sum += logpdf(priors[i], a[i])
     end
@@ -31,9 +27,10 @@ end
 function dynamichmc_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors, transformations;
                               σ=0.01, ϵ=0.001, initial=Float64[], num_samples=1000,
                               kwargs...)
-    likelihood = sol -> sum( sum(logpdf.(Normal(0.0, σ), sol(t) .- data[:, i]))
-                             for (i, t) in enumerate(t) )
-
+    likelihood = function (sol)
+        sum( sum(logpdf.(Normal(0.0, σ), sol(t) .- data[:, i]))
+             for (i, t) in enumerate(t) )
+    end
     dynamichmc_inference(prob, alg, likelihood, priors, transformations;
                          ϵ=ϵ, initial=initial, num_samples=num_samples,
                          kwargs...)
@@ -43,39 +40,11 @@ function dynamichmc_inference(prob::DiffEqBase.DEProblem, alg, likelihood, priors, transformations;
                               ϵ=0.001, initial=Float64[], num_samples=1000,
                               kwargs...)
     P = DynamicHMCPosterior(alg, prob, likelihood, priors, kwargs)
+    PT = TransformedLogDensity(transformations, P)
+    PTG = LogDensityProblems.FluxGradientLogDensity(PT);
-
-    transformations_tuple = Tuple(transformations)
-    parameter_transformation = TransformationTuple(transformations_tuple) # assuming a > 0
-    PT = TransformLogLikelihood(P, parameter_transformation)
-    PTG = ForwardGradientWrapper(PT, zeros(length(priors)));
-
-    lower_bound = Float64[]
-    upper_bound = Float64[]
-
-    for i in priors
-        push!(lower_bound, minimum(i))
-        push!(upper_bound, maximum(i))
-    end
-
-    # If no initial position is given use local minimum near expectation of priors.
-    if length(initial) == 0
-        for i in priors
-            push!(initial, mean(i))
-        end
-        initial_opt = Optim.minimizer(optimize(a -> -P(a),lower_bound,upper_bound,initial,Fminbox(GradientDescent())))
-    end
-
-    initial_inverse_transformed = Float64[]
-    for i in 1:length(initial_opt)
-        para = TransformationTuple(transformations[i])
-        push!(initial_inverse_transformed,inverse(para, (initial_opt[i], ))[1])
-    end
-    #println(initial_inverse_transformed)
-    sample, NUTS_tuned = NUTS_init_tune_mcmc(PTG,
-                                             initial_inverse_transformed,
-                                             num_samples, ϵ=ϵ)

+    chain, NUTS_tuned = NUTS_init_tune_mcmc(PTG,num_samples, ϵ=ϵ)
-    posterior = ungrouping_map(Vector, get_transformation(PT) ∘ get_position, sample)
+    posterior = transform.(Ref(PTG.transformation), get_position.(chain));

-    return posterior, sample, NUTS_tuned
+    return posterior, chain, NUTS_tuned
 end
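For orientation, a minimal end-to-end sketch of the sampling pipeline this file now uses, built only from the calls visible in the diff (DynamicHMC 0.x-era API; the toy posterior `P`, the prior, and the step size `ϵ` are illustrative, not from the PR):

```julia
using DynamicHMC, TransformVariables, LogDensityProblems, Distributions

# Toy stand-in for DynamicHMCPosterior: log likelihood + log prior for a
# single positive parameter `a`, received as a NamedTuple.
P(θ) = logpdf(Normal(θ.a, 1.0), 2.0) + logpdf(Normal(1.5, 1.0), θ.a)

t   = as((a = asℝ₊,))                                # unconstrained ℝ → a > 0
PT  = TransformedLogDensity(t, P)                    # log density on unconstrained space
PTG = LogDensityProblems.FluxGradientLogDensity(PT)  # gradient via Flux AD

chain, NUTS_tuned = NUTS_init_tune_mcmc(PTG, 1000, ϵ = 0.001)
posterior = transform.(Ref(PTG.transformation), get_position.(chain))  # Vector of NamedTuples
```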
2 changes: 1 addition & 1 deletion src/stan_inference.jl
@@ -93,7 +93,7 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing;alg=:
             setup_params = string(setup_params,"row_vector<lower=0>[$length_of_y] sigma$(i-1);")
         end
     end
-    tuple_hyper_params = tuple_hyper_params[1:endof(tuple_hyper_params)-1]
+    tuple_hyper_params = tuple_hyper_params[1:length(tuple_hyper_params)-1]
     differential_equation = generate_differential_equation(f)
     priors_string = string(generate_priors(f,priors))
     stan_likelihood = stan_string(likelihood)
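The lone change here is a Julia 1.0 fix: `endof` was removed (its direct successor is `lastindex`), and for the ASCII strings this routine builds, `length` yields the same last index. A quick illustration with a made-up value of the string:

```julia
s = "sigma1,sigma2,"   # hypothetical built-up comma list
s[1:length(s)-1]       # "sigma1,sigma2" — safe here since s is ASCII
chop(s)                # idiomatic Julia 1.0 way to drop the last character
```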
18 changes: 10 additions & 8 deletions test/dynamicHMC.jl
@@ -1,8 +1,9 @@
 using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions, RecursiveArrayTools
-using DynamicHMC, DiffWrappers, ContinuousTransformations
+using DynamicHMC, TransformVariables
 using Parameters, Distributions, Optim
+using Test

-f1 = @ode_def_nohes LotkaVolterraTest1 begin
+f1 = @ode_def LotkaVolterraTest1 begin
   dx = a*x - x*y
   dy = -3*y + x*y
 end a
@@ -16,28 +17,29 @@ t = collect(range(1,stop=10,length=10)) # observation times
 sol = solve(prob1,Tsit5())
 randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
 data = convert(Array,randomized)
-bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, [Normal(1.5, 1)], [bridge(ℝ, ℝ⁺, )])
-@test mean(bayesian_result[1][1]) ≈ 1.5 atol=1e-1
+bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, [Normal(1.5, 1)], as((a = asℝ₊,)))
+@test mean(a.a for a in bayesian_result[1]) ≈ 1.5 atol=1e-1
+# bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, [Normal(1.5, 1)], as((a = as(Real,0,10),)))

 # With hand-code likelihood function
-weights_ = ones(data) # weighted data
+weights_ = ones(size(data)) # weighted data
 for i = 1:3:length(data)
     weights_[i] = 0
     data[i] = 1e20 # to test that those points are indeed not used
 end
 likelihood = function (sol)
-    l = 0.0
+    l = zero(eltype(first(sol)))
     for (i, t) in enumerate(t)
         l += sum(logpdf.(Normal(0.0, σ), sol(t) - data[:, i]) .* weights_[:,i])
     end
     return l
 end
-bayesian_result = dynamichmc_inference(prob1, Tsit5(), likelihood, [Truncated(Normal(1.5, 1), 0, 2)], [bridge(ℝ, ℝ⁺, )])
+bayesian_result = dynamichmc_inference(prob1, Tsit5(), likelihood, [Truncated(Normal(1.5, 1), 0, 2)], as((a = asℝ₊,)))
 @test mean(bayesian_result[1][1]) ≈ 1.5 atol=1e-1


-f1 = @ode_def_nohes LotkaVolterraTest4 begin
+f1 = @ode_def LotkaVolterraTest4 begin
   dx = a*x - b*x*y
   dy = -c*y + d*x*y
 end a b c d
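For reference, a sketch of the transformation-style migration exercised by these tests: ContinuousTransformations took a vector of per-parameter bridges (`[bridge(ℝ, ℝ⁺)]`), while TransformVariables takes one named transformation whose output is a NamedTuple — hence the `a.a` access in the updated `@test`. The values below are illustrative:

```julia
using TransformVariables

trans = as((a = asℝ₊,))          # replaces [bridge(ℝ, ℝ⁺)]: `a` constrained positive
θ = transform(trans, [0.4])      # (a = exp(0.4),): unconstrained vector → NamedTuple
x = inverse(trans, (a = 1.5,))   # back to the unconstrained vector [log(1.5)]
```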
4 changes: 2 additions & 2 deletions test/stan.jl
@@ -2,7 +2,7 @@ using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions,
   RecursiveArrayTools, Distributions, Test

 println("One parameter case")
-f1 = @ode_def_nohes LotkaVolterraTest1 begin
+f1 = @ode_def LotkaVolterraTest1 begin
   dx = a*x - x*y
   dy = -3y + x*y
 end a
@@ -25,7 +25,7 @@ theta1 = bayesian_result.chain_results[:,["theta.1"],:]


 println("Four parameter case")
-f1 = @ode_def_nohes LotkaVolterraTest4 begin
+f1 = @ode_def LotkaVolterraTest4 begin
   dx = a*x - b*x*y
   dy = -c*y + d*x*y
 end a b c d
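Both test files make the same macro swap; as I read the ParameterizedFunctions release targeted here, the `@ode_def_nohes` variant was folded into plain `@ode_def`, with the macro body unchanged. A minimal usage sketch of the updated definition:

```julia
using ParameterizedFunctions, OrdinaryDiffEq

# One free parameter `a`, listed after `end` as before.
f1 = @ode_def LotkaVolterraTest1 begin
  dx = a*x - x*y
  dy = -3y + x*y
end a

prob = ODEProblem(f1, [1.0, 1.0], (0.0, 10.0), [1.5])
sol  = solve(prob, Tsit5())
```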