Skip to content

Commit

Permalink
Almost there
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored and frankschae committed Jun 29, 2024
1 parent 26120ad commit ed0f98f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
19 changes: 16 additions & 3 deletions src/callback_tracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ function _setup_reverse_callbacks(
du = first(get_tmp_cache(integrator))
λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S)

if sensealg isa GaussAdjoint
dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache
recursive_copyto!(dgrad, 0)
end

# if save_positions[2] = false, then the right limit is not saved. Thus, for
# the QuadratureAdjoint we would need to lift y from the left to the right limit.
# However, one also needs to update dgrad later on.
Expand Down Expand Up @@ -339,7 +344,10 @@ function _setup_reverse_callbacks(
vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS;
dgrad = dgrad, dy = dy)

dgrad !== nothing && (dgrad .*= -1)
if dgrad !== nothing && !(sensealg isa QuadratureAdjoint)
dgrad .*= -1
end

if cb isa Union{ContinuousCallback, VectorContinuousCallback}
# second correction to correct for left limit
@unpack Lu_left = correction
Expand All @@ -358,8 +366,13 @@ function _setup_reverse_callbacks(

λ .=

if !(sensealg isa QuadratureAdjoint) && !(sensealg isa GaussAdjoint)
grad .-= dgrad
if sensealg isa GaussAdjoint
@assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad

#recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad)
elseif !(sensealg isa QuadratureAdjoint)
grad .= dgrad
end
end

Expand Down
22 changes: 13 additions & 9 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
G}
G, SAlg <: GaussAdjoint}
sol::S
p::pType
y::uType
Expand All @@ -8,15 +8,17 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG
f_cache::rateType
pJ::PJT
paramjac_config::PJC
sensealg::GaussAdjoint
sensealg::SAlg
dgdp_cache::DGP
dgdp::G
end

struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
Alg <: GaussAdjoint,
uType, SType, CPS, pType,
fType <: DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction
fType <: DiffEqBase.AbstractDiffEqFunction,
GI <: GaussIntegrand,
ICB} <: SensitivityFunction
diffcache::C
sensealg::Alg
discrete::Bool
Expand All @@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
checkpoint_sol::CPS
prob::pType
f::fType
GaussInt::GaussIntegrand
GaussInt::GI
integrating_cb::ICB
end

TruncatedStacktraces.@truncate_stacktrace ODEGaussAdjointSensitivityFunction
Expand All @@ -41,7 +44,7 @@ end
function ODEGaussAdjointSensitivityFunction(
g, sensealg, gaussint, discrete, sol, dgdu, dgdp,
f, alg,
checkpoints, tols, tstops = nothing;
checkpoints, integrating_cb, tols, tstops = nothing;
tspan = reverse(sol.prob.tspan))
checkpointing = ischeckpointing(sensealg, sol)
(checkpointing && checkpoints === nothing) &&
Expand Down Expand Up @@ -84,7 +87,7 @@ function ODEGaussAdjointSensitivityFunction(
g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg;
quad = true)
return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete,
y, sol, checkpoint_sol, sol.prob, f, gaussint)
y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb)
end

function Gaussfindcursor(intervals, t)
Expand Down Expand Up @@ -202,7 +205,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true
end

@noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg,
GaussInt::GaussIntegrand,
GaussInt::GaussIntegrand, integrating_cb,
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
Expand Down Expand Up @@ -275,7 +278,7 @@ end
λ = zero(u0)
end
sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol,
dgdu_continuous, dgdp_continuous, f, alg, checkpoints,
dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb,
(reltol = reltol, abstol = abstol), tstops, tspan = tspan)

init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end]
Expand Down Expand Up @@ -565,7 +568,8 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,

if sol.prob isa ODEProblem
adj_prob, cb2, rcb = ODEAdjointProblem(
sol, sensealg, alg, integrand, t, dgdu_discrete,
sol, sensealg, alg, integrand, integrating_cb,
t, dgdu_discrete,
dgdp_discrete,
dgdu_continuous, dgdp_continuous, g, Val(true);
checkpoints = checkpoints,
Expand Down

0 comments on commit ed0f98f

Please sign in to comment.