diff --git a/src/callback_tracking.jl b/src/callback_tracking.jl index 295ef6326..a0e752e70 100644 --- a/src/callback_tracking.jl +++ b/src/callback_tracking.jl @@ -272,6 +272,11 @@ function _setup_reverse_callbacks( du = first(get_tmp_cache(integrator)) λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S) + if sensealg isa GaussAdjoint + dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache + recursive_copyto!(dgrad, 0) + end + # if save_positions[2] = false, then the right limit is not saved. Thus, for # the QuadratureAdjoint we would need to lift y from the left to the right limit. # However, one also needs to update dgrad later on. @@ -339,7 +344,10 @@ function _setup_reverse_callbacks( vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS; dgrad = dgrad, dy = dy) - dgrad !== nothing && (dgrad .*= -1) + if dgrad !== nothing && !(sensealg isa QuadratureAdjoint) + dgrad .*= -1 + end + if cb isa Union{ContinuousCallback, VectorContinuousCallback} # second correction to correct for left limit @unpack Lu_left = correction @@ -358,8 +366,13 @@ function _setup_reverse_callbacks( λ .= dλ - if !(sensealg isa QuadratureAdjoint) - grad .-= dgrad + if sensealg isa GaussAdjoint + @assert integrator.f.f isa ODEGaussAdjointSensitivityFunction + integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad + + #recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad) + elseif !(sensealg isa QuadratureAdjoint) + grad .= dgrad end end diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..39a4faa48 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -1,5 +1,5 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP, - G} + G, SAlg <: GaussAdjoint} sol::S p::pType y::uType @@ -8,7 +8,7 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG f_cache::rateType pJ::PJT paramjac_config::PJC - sensealg::GaussAdjoint + sensealg::SAlg dgdp_cache::DGP dgdp::G end @@ -16,7 +16,9 @@ end struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, Alg <: GaussAdjoint, uType, SType, CPS, pType, - fType <: DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction + fType <: DiffEqBase.AbstractDiffEqFunction, + GI <: GaussIntegrand, + ICB} <: SensitivityFunction diffcache::C sensealg::Alg discrete::Bool @@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, checkpoint_sol::CPS prob::pType f::fType - GaussInt::GaussIntegrand + GaussInt::GI + integrating_cb::ICB end TruncatedStacktraces.@truncate_stacktrace ODEGaussAdjointSensitivityFunction @@ -41,7 +44,7 @@ end function ODEGaussAdjointSensitivityFunction( g, sensealg, gaussint, discrete, sol, dgdu, dgdp, f, alg, - checkpoints, tols, tstops = nothing; + checkpoints, integrating_cb, tols, tstops = nothing; tspan = reverse(sol.prob.tspan)) checkpointing = ischeckpointing(sensealg, sol) (checkpointing && checkpoints === nothing) && @@ -84,7 +87,7 @@ function ODEGaussAdjointSensitivityFunction( g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg; quad = true) return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete, - y, sol, checkpoint_sol, sol.prob, f, gaussint) + y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb) end function Gaussfindcursor(intervals, t) @@ -202,7 +205,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true end @noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg, - GaussInt::GaussIntegrand, + GaussInt::GaussIntegrand, integrating_cb, t = nothing, dgdu_discrete::DG1 = nothing, dgdp_discrete::DG2 = nothing, @@ -275,7 +278,7 @@ end λ = zero(u0) end sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol, - dgdu_continuous, dgdp_continuous, f, alg, checkpoints, + dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb, (reltol = reltol, abstol = abstol), tstops, tspan = tspan) init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end] @@ -565,7 +568,8 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing, if sol.prob isa ODEProblem adj_prob, cb2, rcb = ODEAdjointProblem( - sol, sensealg, alg, integrand, t, dgdu_discrete, + sol, sensealg, alg, integrand, integrating_cb, + t, dgdu_discrete, dgdp_discrete, dgdu_continuous, dgdp_continuous, g, Val(true); checkpoints = checkpoints, diff --git a/test/callbacks/continuous_callbacks.jl b/test/callbacks/continuous_callbacks.jl index 41d5e951a..a5d0a1690 100644 --- a/test/callbacks/continuous_callbacks.jl +++ b/test/callbacks/continuous_callbacks.jl @@ -291,5 +291,9 @@ println("Continuous Callbacks") sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()) gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1] @test gFD≈gZy rtol=1e-10 + + sensealg = GaussAdjoint(autojacvec = EnzymeVJP()) + gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1] + @test gFD≈gZy rtol=1e-10 end end diff --git a/test/callbacks/continuous_vs_discrete.jl b/test/callbacks/continuous_vs_discrete.jl index 6d2f3be31..cf6e6f240 100644 --- a/test/callbacks/continuous_vs_discrete.jl +++ b/test/callbacks/continuous_vs_discrete.jl @@ -162,6 +162,13 @@ function test_continuous_wrt_discrete_callback() saveat = tspan[2], save_start = false)), u0, p) + du03, dp3 = Zygote.gradient( + (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, + sensealg = GaussAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + dstuff = ForwardDiff.gradient( (θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], callback = cb, saveat = tspan[2], @@ -173,8 +180,12 @@ function test_continuous_wrt_discrete_callback() @test dp1 ≈ dstuff[3:4] @test du02 ≈ dstuff[1:2] @test dp2 ≈ dstuff[3:4] + @test du03 ≈ dstuff[1:2] + @test dp3 ≈ dstuff[3:4] @test du01 ≈ du02 @test dp1 ≈ dp2 + @test du01 ≈ du03 + @test dp1 ≈ dp3 end @testset "Compare continuous with discrete callbacks" begin diff --git a/test/callbacks/discrete_callbacks.jl b/test/callbacks/discrete_callbacks.jl index 817cea35b..3e475dcdd 100644 --- a/test/callbacks/discrete_callbacks.jl +++ b/test/callbacks/discrete_callbacks.jl @@ -99,6 +99,14 @@ function test_discrete_callback(cb, tstops, g, dg!, cboop = nothing, tprev = fal sensealg = QuadratureAdjoint())), u0, p) + du05, dp5 = Zygote.gradient( + (u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = GaussAdjoint())), + u0, p) + dstuff = ForwardDiff.gradient( (θ) -> g(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:6], callback = cb, tstops = tstops, @@ -135,9 +143,11 @@ function test_discrete_callback(cb, tstops, g, dg!, cboop = nothing, tprev = fal @test du01≈du03c rtol=1e-7 @test du03 ≈ du03c @test du01 ≈ du04 + @test du01 ≈ du05 @test dp1 ≈ dp3 @test dp1 ≈ dp3c @test dp1≈dp4 rtol=1e-7 + @test dp1≈dp5 rtol=1e-7 cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, BacksolveAdjoint(autojacvec = ReverseDiffVJP())) diff --git a/test/callbacks/vector_continuous_callbacks.jl b/test/callbacks/vector_continuous_callbacks.jl index cbf37e943..8c4349cd1 100644 --- a/test/callbacks/vector_continuous_callbacks.jl +++ b/test/callbacks/vector_continuous_callbacks.jl @@ -29,6 +29,14 @@ function test_vector_continuous_callback(cb, g) sensealg = BacksolveAdjoint())), u0, p) + du02, dp2 = @time Zygote.gradient( + (u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = GaussAdjoint())), + u0, p) + dstuff = @time ForwardDiff.gradient( (θ) -> g(solve(prob, Tsit5(), u0 = θ[1:4], p = θ[5:6], callback = cb, @@ -38,6 +46,8 @@ function test_vector_continuous_callback(cb, g) @test du01 ≈ dstuff[1:4] @test dp1 ≈ dstuff[5:6] + @test du02 ≈ dstuff[1:4] + @test dp2 ≈ dstuff[5:6] end @testset "VectorContinuous callbacks" begin