Skip to content

Commit

Permalink
Fix logpdf to work with Tracked values
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Nov 5, 2018
1 parent 3414670 commit be19cb1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/dynamichmc_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ struct DynamicHMCPosterior
kwargs
end

# logpdf(dist,data::Flux.Tracker.TrackedReal) = Distributions.logpdf(dist,data.data)

function (P::DynamicHMCPosterior)(a)
@unpack alg, problem, likelihood, priors, kwargs = P
prob = remake(problem,u0 = convert.(eltype(a.a),problem.u0),p=a.a)
Expand All @@ -30,7 +32,7 @@ end
function dynamichmc_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors, transformations;
σ=0.01, ϵ=0.001, initial=Float64[], num_samples=1000,
kwargs...)
likelihood = sol -> sum( sum(logpdf.(Normal(0.0, σ), sol(t) .- data[:, i]))
likelihood = sol -> sum( sum((sol(t) .- data[:, i]).^2)
for (i, t) in enumerate(t) )
dynamichmc_inference(prob, alg, likelihood, priors, transformations;
ϵ=ϵ, initial=initial, num_samples=num_samples,
Expand Down
6 changes: 3 additions & 3 deletions test/dynamicHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ t = collect(range(1,stop=10,length=10)) # observation times
sol = solve(prob1,Tsit5())
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)
bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, [Normal(1.5, 1)], as((a = asℝ₊,)),num_samples=10000)
bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, [Normal(1.5, 1)], as((a = asℝ₊,)),num_samples=100)
@test mean(a.a for a in bayesian_result[1]) 1.5 atol=1e-1

# With hand-code likelihood function
Expand All @@ -26,15 +26,15 @@ for i = 1:3:length(data)
weights_[i] = 0
data[i] = 1e20 # to test that those points are indeed not used
end
logpdf(dist,data::Flux.Tracker.TrackedReal) = logpdf(dist,value(data))

likelihood = function (sol)
l = 0.0
for (i, t) in enumerate(t)
l += sum(((sol(t) - data[:, i]).^2) .* weights_[:,i])
end
return l
end
bayesian_result = dynamichmc_inference(prob1, Tsit5(), likelihood, [Truncated(Normal(1.5, 1), 0, 2)],as((a = asℝ₊,)))
bayesian_result = dynamichmc_inference(prob1, Tsit5(), likelihood, [Truncated(Normal(1.5, 1), 0, 2)],as((a = asℝ₊,)),num_samples=100)
@test mean(bayesian_result[1][1]) 1.5 atol=1e-1


Expand Down

0 comments on commit be19cb1

Please sign in to comment.