From b973ded0d9822083c78e403e11be2e68345b384d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 29 May 2024 04:04:17 +0200 Subject: [PATCH 1/5] Fix GaussAdjoint with callbacks --- src/callback_tracking.jl | 2 +- test/callbacks/continuous_callbacks.jl | 4 ++++ test/callbacks/continuous_vs_discrete.jl | 11 +++++++++++ test/callbacks/discrete_callbacks.jl | 10 ++++++++++ test/callbacks/vector_continuous_callbacks.jl | 10 ++++++++++ 5 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/callback_tracking.jl b/src/callback_tracking.jl index 295ef6326..daa2480d0 100644 --- a/src/callback_tracking.jl +++ b/src/callback_tracking.jl @@ -358,7 +358,7 @@ function _setup_reverse_callbacks( λ .= dλ - if !(sensealg isa QuadratureAdjoint) + if !(sensealg isa QuadratureAdjoint) && !(sensealg isa GaussAdjoint) grad .-= dgrad end end diff --git a/test/callbacks/continuous_callbacks.jl b/test/callbacks/continuous_callbacks.jl index 41d5e951a..a5d0a1690 100644 --- a/test/callbacks/continuous_callbacks.jl +++ b/test/callbacks/continuous_callbacks.jl @@ -291,5 +291,9 @@ println("Continuous Callbacks") sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()) gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1] @test gFD≈gZy rtol=1e-10 + + sensealg = GaussAdjoint(autojacvec = EnzymeVJP()) + gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1] + @test gFD≈gZy rtol=1e-10 end end diff --git a/test/callbacks/continuous_vs_discrete.jl b/test/callbacks/continuous_vs_discrete.jl index 6d2f3be31..7570edb23 100644 --- a/test/callbacks/continuous_vs_discrete.jl +++ b/test/callbacks/continuous_vs_discrete.jl @@ -162,6 +162,13 @@ function test_continuous_wrt_discrete_callback() saveat = tspan[2], save_start = false)), u0, p) + du03, dp3 = Zygote.gradient( + (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, + sensealg = GaussAdjoint(), + saveat = tspan[2], save_start = false)), + u0, p) + dstuff = ForwardDiff.gradient( (θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], callback = cb, saveat = tspan[2], @@ -173,8 +180,12 @@ function test_continuous_wrt_discrete_callback() @test dp1 ≈ dstuff[3:4] @test du02 ≈ dstuff[1:2] @test dp2 ≈ dstuff[3:4] + @test du03 ≈ dstuff[1:2] + @test dp3 ≈ dstuff[3:4] @test du01 ≈ du02 @test dp1 ≈ dp2 + @test du01 ≈ du03 + @test dp1 ≈ dp3 end @testset "Compare continuous with discrete callbacks" begin diff --git a/test/callbacks/discrete_callbacks.jl b/test/callbacks/discrete_callbacks.jl index 817cea35b..3e475dcdd 100644 --- a/test/callbacks/discrete_callbacks.jl +++ b/test/callbacks/discrete_callbacks.jl @@ -99,6 +99,14 @@ function test_discrete_callback(cb, tstops, g, dg!, cboop = nothing, tprev = fal sensealg = QuadratureAdjoint())), u0, p) + du05, dp5 = Zygote.gradient( + (u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, tstops = tstops, + abstol = abstol, reltol = reltol, + saveat = savingtimes, + sensealg = GaussAdjoint())), + u0, p) + dstuff = ForwardDiff.gradient( (θ) -> g(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:6], callback = cb, tstops = tstops, @@ -135,9 +143,11 @@ function test_discrete_callback(cb, tstops, g, dg!, cboop = nothing, tprev = fal @test du01≈du03c rtol=1e-7 @test du03 ≈ du03c @test du01 ≈ du04 + @test du01 ≈ du05 @test dp1 ≈ dp3 @test dp1 ≈ dp3c @test dp1≈dp4 rtol=1e-7 + @test dp1≈dp5 rtol=1e-7 cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, BacksolveAdjoint(autojacvec = ReverseDiffVJP())) diff --git a/test/callbacks/vector_continuous_callbacks.jl b/test/callbacks/vector_continuous_callbacks.jl index cbf37e943..8c4349cd1 100644 --- a/test/callbacks/vector_continuous_callbacks.jl +++ b/test/callbacks/vector_continuous_callbacks.jl @@ -29,6 +29,14 @@ function test_vector_continuous_callback(cb, g) sensealg = BacksolveAdjoint())), u0, p) + du02, dp2 = @time Zygote.gradient( + (u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p, + callback = cb, abstol = abstol, + reltol = reltol, + saveat = savingtimes, + sensealg = GaussAdjoint())), + u0, p) + dstuff = @time ForwardDiff.gradient( (θ) -> g(solve(prob, Tsit5(), u0 = θ[1:4], p = θ[5:6], callback = cb, @@ -38,6 +46,8 @@ function test_vector_continuous_callback(cb, g) @test du01 ≈ dstuff[1:4] @test dp1 ≈ dstuff[5:6] + @test du02 ≈ dstuff[1:4] + @test dp2 ≈ dstuff[5:6] end @testset "VectorContinuous callbacks" begin From 9d57682f55c58f348df3b447cbbcb0fb4340c98d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 29 May 2024 08:58:25 +0200 Subject: [PATCH 2/5] Allow for Enzyme choice with callbacks automatically --- src/sensitivity_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index 0353626ed..8203d8e3a 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -386,7 +386,7 @@ function adjoint_sensitivities(sol, args...; setvjp(sensealg, ZygoteVJP()) end else - _sensealg = setvjp(sensealg, ReverseDiffVJP()) + setvjp(sensealg, ZygoteVJP()) end return try From 8eeaddc76471bc8a9abd7c1f95a382437f36c15b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 7 Jun 2024 00:15:09 -0400 Subject: [PATCH 3/5] fix autojacvec choices --- test/callbacks/continuous_vs_discrete.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/callbacks/continuous_vs_discrete.jl b/test/callbacks/continuous_vs_discrete.jl index 7570edb23..cf6e6f240 100644 --- a/test/callbacks/continuous_vs_discrete.jl +++ b/test/callbacks/continuous_vs_discrete.jl @@ -168,7 +168,7 @@ function test_continuous_wrt_discrete_callback() sensealg = GaussAdjoint(), saveat = tspan[2], save_start = false)), u0, p) - + dstuff = ForwardDiff.gradient( (θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4], callback = cb, saveat = tspan[2], From 26120ad32606a94ddc3cdbb111e60553b342458c Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 7 Jun 2024 05:22:11 -0400 Subject: [PATCH 4/5] revert enzyme choice on callbacks --- src/sensitivity_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index 8203d8e3a..0353626ed 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -386,7 +386,7 @@ function adjoint_sensitivities(sol, args...; setvjp(sensealg, ZygoteVJP()) end else - setvjp(sensealg, ZygoteVJP()) + _sensealg = setvjp(sensealg, ReverseDiffVJP()) end return try From ed0f98feee61000d8f772454c71102dcd7a35671 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 7 Jun 2024 10:41:01 -0400 Subject: [PATCH 5/5] Almost there --- src/callback_tracking.jl | 19 ++++++++++++++++--- src/gauss_adjoint.jl | 22 +++++++++++++--------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/callback_tracking.jl b/src/callback_tracking.jl index daa2480d0..a0e752e70 100644 --- a/src/callback_tracking.jl +++ b/src/callback_tracking.jl @@ -272,6 +272,11 @@ function _setup_reverse_callbacks( du = first(get_tmp_cache(integrator)) λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S) + if sensealg isa GaussAdjoint + dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache + recursive_copyto!(dgrad, 0) + end + # if save_positions[2] = false, then the right limit is not saved. Thus, for # the QuadratureAdjoint we would need to lift y from the left to the right limit. # However, one also needs to update dgrad later on. @@ -339,7 +344,10 @@ function _setup_reverse_callbacks( vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS; dgrad = dgrad, dy = dy) - dgrad !== nothing && (dgrad .*= -1) + if dgrad !== nothing && !(sensealg isa QuadratureAdjoint) + dgrad .*= -1 + end + if cb isa Union{ContinuousCallback, VectorContinuousCallback} # second correction to correct for left limit @unpack Lu_left = correction @@ -358,8 +366,13 @@ function _setup_reverse_callbacks( λ .= dλ - if !(sensealg isa QuadratureAdjoint) && !(sensealg isa GaussAdjoint) - grad .-= dgrad + if sensealg isa GaussAdjoint + @assert integrator.f.f isa ODEGaussAdjointSensitivityFunction + integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad + + #recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad) + elseif !(sensealg isa QuadratureAdjoint) + grad .= dgrad end end diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..39a4faa48 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -1,5 +1,5 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP, - G} + G, SAlg <: GaussAdjoint} sol::S p::pType y::uType @@ -8,7 +8,7 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG f_cache::rateType pJ::PJT paramjac_config::PJC - sensealg::GaussAdjoint + sensealg::SAlg dgdp_cache::DGP dgdp::G end @@ -16,7 +16,9 @@ end struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, Alg <: GaussAdjoint, uType, SType, CPS, pType, - fType <: DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction + fType <: DiffEqBase.AbstractDiffEqFunction, + GI <: GaussIntegrand, + ICB} <: SensitivityFunction diffcache::C sensealg::Alg discrete::Bool @@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, checkpoint_sol::CPS prob::pType f::fType - GaussInt::GaussIntegrand + GaussInt::GI + integrating_cb::ICB end TruncatedStacktraces.@truncate_stacktrace ODEGaussAdjointSensitivityFunction @@ -41,7 +44,7 @@ end function ODEGaussAdjointSensitivityFunction( g, sensealg, gaussint, discrete, sol, dgdu, dgdp, f, alg, - checkpoints, tols, tstops = nothing; + checkpoints, integrating_cb, tols, tstops = nothing; tspan = reverse(sol.prob.tspan)) checkpointing = ischeckpointing(sensealg, sol) (checkpointing && checkpoints === nothing) && @@ -84,7 +87,7 @@ function ODEGaussAdjointSensitivityFunction( g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg; quad = true) return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete, - y, sol, checkpoint_sol, sol.prob, f, gaussint) + y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb) end function Gaussfindcursor(intervals, t) @@ -202,7 +205,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true end @noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg, - GaussInt::GaussIntegrand, + GaussInt::GaussIntegrand, integrating_cb, t = nothing, dgdu_discrete::DG1 = nothing, dgdp_discrete::DG2 = nothing, @@ -275,7 +278,7 @@ end λ = zero(u0) end sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol, - dgdu_continuous, dgdp_continuous, f, alg, checkpoints, + dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb, (reltol = reltol, abstol = abstol), tstops, tspan = tspan) init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end] @@ -565,7 +568,8 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing, if sol.prob isa ODEProblem adj_prob, cb2, rcb = ODEAdjointProblem( - sol, sensealg, alg, integrand, t, dgdu_discrete, + sol, sensealg, alg, integrand, integrating_cb, + t, dgdu_discrete, dgdp_discrete, dgdu_continuous, dgdp_continuous, g, Val(true); checkpoints = checkpoints,