Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mooncake Inside Problems #1152

Merged
Merged
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -71,6 +72,7 @@ LinearSolve = "2"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
Mooncake = "0.4.52"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1"
Optimization = "4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/manual/differential_equation_sensitivities.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ ZygoteVJP
EnzymeVJP
TrackerVJP
ReverseDiffVJP
MooncakeVJP
```

## More Details on Sensitivity Algorithm Choices
Expand Down
1 change: 1 addition & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Zer
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Mooncake: Mooncake
using Tracker: Tracker, TrackedArray
using ReverseDiff: ReverseDiff
using Zygote: Zygote
Expand Down
48 changes: 48 additions & 0 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg)
pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
paramjac_config = (paramjac_config..., Enzyme.make_zero(pf))
elseif autojacvec isa MooncakeVJP
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas, it is probably better to refactor these hard-coded branches (e.g., define an interface function that other packages can overload). It would help

  • autograd tools to integrate with SciMLSensitivity easily
  • move some existing autograd glue code into package extensions to avoid hard deps

It might also help to switch to DI where possible to avoid duplicate glue code in the ecosystem. @gdalle

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

autograd tools to integrate with SciMLSensitivity easily

Is that really a high priority right now? How many more autograd packages are you going to write this year that will be useful?

move some existing autograd glue code into package extensions to avoid hard deps

Doesn't doesn't necessarily make sense. Most of the methods are used in the default method so they would be required to be loaded by default anyways?

It might also help to switch to DI where possible to avoid duplicate glue code in the ecosystem. @gdalle

That's the plan when it's able to handle this case well. Currently it's not able to.

Copy link

@yebai yebai Nov 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that really a high priority right now? How many more autograd packages are you going to write this year that will be useful?

Mooncake is getting a new forward mode (an attempt to improve ForwardDiff with GPU compatibility and fewer constraints; see here for more details), so @willtebbutt will likely need to modify these again in the near term.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't that just require modifying https://github.com/SciML/SciMLSensitivity.jl/pull/1152/files#diff-1a15b4b5711133c125548ef7f1ca88f761bb124cffc8bfde8c13336968aaccd6R466 ? I don't see why that would touch this function and instead just dispatch on there.

I mean, if someone wants to do a refactor here that's perfectly fine. But I also don't see why it would be a high priority since it's not like new AD systems get added every year, and modifications to existing ones don't really touch this part of the code much. I would think the time would be better spent just trying to get DI up to speed than refactoring this old code.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can discuss DI integration in #1040 if you want

pf = get_pf(autojacvec, prob, unwrappedf)
paramjac_config = get_paramjac_config(autojacvec, pf, p, f, y, _t)
elseif SciMLBase.has_paramjac(f) || quad || !(autojacvec isa Bool) ||
autojacvec isa EnzymeVJP
paramjac_config = nothing
Expand Down Expand Up @@ -460,6 +463,13 @@ function get_paramjac_config(autojacvec::EnzymeVJP, p, f, y, _p, _t; numindvar,
return paramjac_config
end

function get_paramjac_config(::MooncakeVJP, pf, p, f, y, _t)
dy_mem = zero(y)
λ_mem = zero(y)
cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t)
return cache, pf, λ_mem, dy_mem
end

function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
isRODE = nothing)
nothing
Expand Down Expand Up @@ -492,6 +502,44 @@ function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
end
end

function get_pf(::MooncakeVJP, prob, _f)
isinplace = DiffEqBase.isinplace(prob)
isRODE = isa(prob, RODEProblem)
pf = let f = _f
if isinplace && isRODE
function (out, u, _p, t, W)
f(out, u, _p, t, W)
return out
end
elseif isinplace
function (out, u, _p, t)
f(out, u, _p, t)
return out
end
elseif !isinplace && isRODE
function (out, u, _p, t, W)
out .= f(u, _p, t, W)
return out
end
else
# !isinplace
function (out, u, _p, t)
out .= f(u, _p, t)
return out
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
end
end
end
end

function mooncake_run_ad(paramjac_config, y, p, t, λ)
cache, pf, λ_mem, dy_mem = paramjac_config
λ_mem .= λ
dy, _ = Mooncake.value_and_pullback!!(cache, λ_mem, pf, dy_mem, y, p, t)
y_grad = cache.tangents[3]
p_grad = cache.tangents[4]
return dy, y_grad, p_grad
end

function getprob(S::SensitivityFunction)
(S isa ODEBacksolveSensitivityFunction) ? S.prob : S.sol.prob
end
Expand Down
8 changes: 8 additions & 0 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
return
end

function _vecjacobian!(dλ, y, λ, p, t, S::SensitivityFunction, ::MooncakeVJP, dgrad, dy, W)
_dy, y_grad, p_grad = mooncake_run_ad(S.diffcache.paramjac_config, y, p, t, λ)
dy !== nothing && recursive_copyto!(dy, _dy)
dλ !== nothing && recursive_copyto!(dλ, y_grad)
dgrad !== nothing && recursive_copyto!(dgrad, p_grad)
return
end

function jacNoise!(λ, y, p, t, S::SensitivityFunction;
dgrad = nothing, dλ = nothing, dy = nothing)
_jacNoise!(λ, y, p, t, S, S.sensealg.autojacvec, dgrad, dλ, dy)
Expand Down
7 changes: 7 additions & 0 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
end
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
pf = get_pf(sensealg.autojacvec, prob, f)
paramjac_config = get_paramjac_config(sensealg.autojacvec, pf, p, f, y, tspan[2])
pJ = nothing
elseif isautojacvec # Zygote
paramjac_config = nothing
pf = nothing
Expand Down Expand Up @@ -500,6 +504,9 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
elseif sensealg.autojacvec isa MooncakeVJP
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
out .= p_grad
else
error("autojacvec choice $(sensealg.autojacvec) is not supported by GaussAdjoint")
end
Expand Down
7 changes: 7 additions & 0 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
end
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
pf = get_pf(sensealg.autojacvec, prob, f)
paramjac_config = get_paramjac_config(sensealg.autojacvec, pf, p, f, y, tspan[2])
pJ = nothing
elseif isautojacvec # Zygote
paramjac_config = nothing
pf = nothing
Expand Down Expand Up @@ -288,6 +292,9 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
else
out[:] .= vec(tmp[1])
end
elseif sensealg.autojacvec isa MooncakeVJP
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
out .= p_grad
elseif sensealg.autojacvec isa EnzymeVJP
tmp3, tmp4, tmp6 = paramjac_config
tmp4 .= λ
Expand Down
17 changes: 17 additions & 0 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,23 @@ struct ReverseDiffVJP{compile} <: VJPChoice
ReverseDiffVJP(compile = false) = new{compile}()
end

"""
```julia
MooncakeVJP <: VJPChoice
```

Uses Mooncake.jl to compute the vector-Jacobian products.

Does not support GPUs (CuArrays).

## Constructor

```julia
MooncakeVJP()
```
"""
struct MooncakeVJP <: VJPChoice end

@inline convert_tspan(::ForwardDiffSensitivity{CS, CTS}) where {CS, CTS} = CTS
@inline convert_tspan(::Any) = nothing
@inline function alg_autodiff(alg::AbstractSensitivityAlgorithm{
Expand Down
70 changes: 59 additions & 11 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ _, easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint())
_, easy_res15 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res16 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res142 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
Expand All @@ -158,6 +166,10 @@ _, easy_res146 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discret
sensealg = GaussAdjoint(checkpointing = true,
autojacvec = false),
checkpoints = sol.t[1:500:end])
_, easy_res147 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
adj_prob = ODEAdjointProblem(sol,
QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14,
autojacvec = SciMLSensitivity.ReverseDiffVJP()),
Expand Down Expand Up @@ -189,11 +201,14 @@ res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-12)
@test isapprox(res, easy_res12, rtol = 1e-9)
@test isapprox(res, easy_res13, rtol = 1e-9)
@test isapprox(res, easy_res14, rtol = 1e-9)
@test isapprox(res, easy_res15, rtol = 1e-9)
@test isapprox(res, easy_res16, rtol = 1e-9)
@test isapprox(res, easy_res142, rtol = 1e-9)
@test isapprox(res, easy_res143, rtol = 1e-9)
@test isapprox(res, easy_res144, rtol = 1e-9)
@test isapprox(res, easy_res145, rtol = 1e-9)
@test isapprox(res, easy_res146, rtol = 1e-9)
@test isapprox(res, easy_res147, rtol = 1e-9)

println("OOP adjoint sensitivities ")

Expand All @@ -203,14 +218,11 @@ _, easy_res = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(abstol = 1e-14,
reltol = 1e-14))
sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14))
_, easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = false,
abstol = 1e-14,
reltol = 1e-14))
sensealg = QuadratureAdjoint(autojacvec = false, abstol = 1e-14, reltol = 1e-14))
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
Expand All @@ -224,17 +236,15 @@ _, easy_res3 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
@test easy_res32 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa
AbstractArray
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa AbstractArray
_, easy_res4 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint())
@test easy_res42 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa
AbstractArray
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa AbstractArray
_, easy_res5 = adjoint_sensitivities(soloop,
Kvaerno5(nlsolve = NLAnderson(), smooth_est = false),
t = t, dgdu_discrete = dg, abstol = 1e-12,
Expand All @@ -248,8 +258,7 @@ _, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discre
_, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg, abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true,
autojacvec = false),
sensealg = InterpolatingAdjoint(checkpointing = true, autojacvec = false),
checkpoints = soloop_nodense.t[1:5:end])

_, easy_res8 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discrete = dg,
Expand Down Expand Up @@ -289,6 +298,39 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
reltol = 1e-14,
sensealg = GaussAdjoint(checkpointing = true),
checkpoints = soloop_nodense.t[1:5:end])

_, easy_res2_mc_quad = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res2_mc_interp = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res2_mc_back = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res6_mc_quad = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res6_mc_interp = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true,
autojacvec = SciMLSensitivity.MooncakeVJP()),
checkpoints = soloop_nodense.t[1:5:end])
_, easy_res6_mc_back = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))

@test isapprox(res, easy_res, rtol = 1e-10)
@test isapprox(res, easy_res2, rtol = 1e-10)
@test isapprox(res, easy_res22, rtol = 1e-10)
Expand All @@ -309,6 +351,12 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
@test isapprox(res, easy_res12, rtol = 1e-9)
@test isapprox(res, easy_res122, rtol = 1e-9)
@test isapprox(res, easy_res123, rtol = 1e-4)
@test isapprox(res, easy_res2_mc_quad, rtol=1e-9)
@test isapprox(res, easy_res2_mc_interp, rtol=1e-9)
@test isapprox(res, easy_res2_mc_back, rtol=1e-9)
@test isapprox(res, easy_res6_mc_quad, rtol=1e-4)
@test isapprox(res, easy_res6_mc_interp, rtol=1e-9)
@test isapprox(res, easy_res6_mc_back, rtol=1e-9)

println("Calculate adjoint sensitivities ")

Expand Down
Loading