diff --git a/src/callback_tracking.jl b/src/callback_tracking.jl index daa2480d0..a0e752e70 100644 --- a/src/callback_tracking.jl +++ b/src/callback_tracking.jl @@ -272,6 +272,11 @@ function _setup_reverse_callbacks( du = first(get_tmp_cache(integrator)) λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S) + if sensealg isa GaussAdjoint + dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache + recursive_copyto!(dgrad, 0) + end + # if save_positions[2] = false, then the right limit is not saved. Thus, for # the QuadratureAdjoint we would need to lift y from the left to the right limit. # However, one also needs to update dgrad later on. @@ -339,7 +344,10 @@ function _setup_reverse_callbacks( vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS; dgrad = dgrad, dy = dy) - dgrad !== nothing && (dgrad .*= -1) + if dgrad !== nothing && !(sensealg isa QuadratureAdjoint) + dgrad .*= -1 + end + if cb isa Union{ContinuousCallback, VectorContinuousCallback} # second correction to correct for left limit @unpack Lu_left = correction @@ -358,8 +366,13 @@ function _setup_reverse_callbacks( λ .= dλ - if !(sensealg isa QuadratureAdjoint) && !(sensealg isa GaussAdjoint) - grad .-= dgrad + if sensealg isa GaussAdjoint + @assert integrator.f.f isa ODEGaussAdjointSensitivityFunction + integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad + + #recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad) + elseif !(sensealg isa QuadratureAdjoint) + grad .= dgrad end end diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..39a4faa48 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -1,5 +1,5 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP, - G} + G, SAlg <: GaussAdjoint} sol::S p::pType y::uType @@ -8,7 +8,7 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG f_cache::rateType pJ::PJT paramjac_config::PJC - sensealg::GaussAdjoint + sensealg::SAlg dgdp_cache::DGP dgdp::G end @@ -16,7 +16,9 @@ end struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, Alg <: GaussAdjoint, uType, SType, CPS, pType, - fType <: DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction + fType <: DiffEqBase.AbstractDiffEqFunction, + GI <: GaussIntegrand, + ICB} <: SensitivityFunction diffcache::C sensealg::Alg discrete::Bool @@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, checkpoint_sol::CPS prob::pType f::fType - GaussInt::GaussIntegrand + GaussInt::GI + integrating_cb::ICB end TruncatedStacktraces.@truncate_stacktrace ODEGaussAdjointSensitivityFunction @@ -41,7 +44,7 @@ end function ODEGaussAdjointSensitivityFunction( g, sensealg, gaussint, discrete, sol, dgdu, dgdp, f, alg, - checkpoints, tols, tstops = nothing; + checkpoints, integrating_cb, tols, tstops = nothing; tspan = reverse(sol.prob.tspan)) checkpointing = ischeckpointing(sensealg, sol) (checkpointing && checkpoints === nothing) && @@ -84,7 +87,7 @@ function ODEGaussAdjointSensitivityFunction( g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg; quad = true) return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete, - y, sol, checkpoint_sol, sol.prob, f, gaussint) + y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb) end function Gaussfindcursor(intervals, t) @@ -202,7 +205,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true end @noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg, - GaussInt::GaussIntegrand, + GaussInt::GaussIntegrand, integrating_cb, t = nothing, dgdu_discrete::DG1 = nothing, dgdp_discrete::DG2 = nothing, @@ -275,7 +278,7 @@ end λ = zero(u0) end sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol, - dgdu_continuous, dgdp_continuous, f, alg, checkpoints, + dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb, (reltol = reltol, abstol = abstol), tstops, tspan = tspan) init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end] @@ -565,7 +568,8 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing, if sol.prob isa ODEProblem adj_prob, cb2, rcb = ODEAdjointProblem( - sol, sensealg, alg, integrand, t, dgdu_discrete, + sol, sensealg, alg, integrand, integrating_cb, + t, dgdu_discrete, dgdp_discrete, dgdu_continuous, dgdp_continuous, g, Val(true); checkpoints = checkpoints,