diff --git a/Project.toml b/Project.toml index 0821da47c..c4706497a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLSensitivity" uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1" authors = ["Christopher Rackauckas ", "Yingbo Ma "] -version = "7.71.2" +version = "7.72.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -42,6 +42,12 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[weakdeps] +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" + +[extensions] +SciMLSensitivityMooncakeExt = "Mooncake" + [compat] ADTypes = "1.9" Accessors = "0.1.36" @@ -71,6 +77,7 @@ LinearSolve = "2" Lux = "1" Markdown = "1.10" ModelingToolkit = "9.42" +Mooncake = "0.4.52" NLsolve = "4.5.1" NonlinearSolve = "3.0.1" Optimization = "4" @@ -110,6 +117,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -123,4 +131,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] +test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] diff --git a/docs/src/manual/differential_equation_sensitivities.md b/docs/src/manual/differential_equation_sensitivities.md index 433f0bf4a..c54dadc1e 100644 --- a/docs/src/manual/differential_equation_sensitivities.md +++ b/docs/src/manual/differential_equation_sensitivities.md @@ -212,6 +212,7 @@ ZygoteVJP EnzymeVJP TrackerVJP ReverseDiffVJP +MooncakeVJP ``` ## More Details on Sensitivity Algorithm Choices diff --git a/ext/SciMLSensitivityMooncakeExt.jl b/ext/SciMLSensitivityMooncakeExt.jl new file mode 100644 index 000000000..7f0f8f35a --- /dev/null +++ b/ext/SciMLSensitivityMooncakeExt.jl @@ -0,0 +1,22 @@ +module SciMLSensitivityMooncakeExt + +using SciMLSensitivity, Mooncake +import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded + +function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t) + dy_mem = zero(y) + λ_mem = zero(y) + cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t) + return cache, pf, λ_mem, dy_mem +end + +function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ) + cache, pf, λ_mem, dy_mem = paramjac_config + λ_mem .= λ + dy, _ = Mooncake.value_and_pullback!!(cache, λ_mem, pf, dy_mem, y, p, t) + y_grad = cache.tangents[3] + p_grad = cache.tangents[4] + return dy, y_grad, p_grad +end + +end diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 7e83ceb51..8b708f00c 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -211,6 +211,9 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg) pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE) paramjac_config = (paramjac_config..., Enzyme.make_zero(pf)) + elseif autojacvec isa MooncakeVJP + pf = get_pf(autojacvec, prob, unwrappedf) + paramjac_config = get_paramjac_config(MooncakeLoaded(), autojacvec, pf, p, f, y, _t) elseif SciMLBase.has_paramjac(f) || quad || !(autojacvec isa Bool) || autojacvec isa EnzymeVJP paramjac_config = nothing @@ -460,6 +463,15 @@ function get_paramjac_config(autojacvec::EnzymeVJP, p, f, y, _p, _t; numindvar, return paramjac_config end +# Dispatched on inside extension. +struct MooncakeLoaded end + +function get_paramjac_config(::Any, ::MooncakeVJP, pf, p, f, y, _t) + msg = "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " * + "`using Mooncake` to use this functionality" + error(msg) +end + function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing, isRODE = nothing) nothing @@ -492,6 +504,41 @@ function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE) end end +function get_pf(::MooncakeVJP, prob, _f) + isinplace = DiffEqBase.isinplace(prob) + isRODE = isa(prob, RODEProblem) + pf = let f = _f + if isinplace && isRODE + function (out, u, _p, t, W) + f(out, u, _p, t, W) + return out + end + elseif isinplace + function (out, u, _p, t) + f(out, u, _p, t) + return out + end + elseif !isinplace && isRODE + function (out, u, _p, t, W) + out .= f(u, _p, t, W) + return out + end + else + # !isinplace + function (out, u, _p, t) + out .= f(u, _p, t) + return out + end + end + end +end + +function mooncake_run_ad(paramjac_config, y, p, t, λ) + msg = "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " * + "`using Mooncake` to use this functionality" + error(msg) +end + function getprob(S::SensitivityFunction) (S isa ODEBacksolveSensitivityFunction) ? S.prob : S.sol.prob end diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 1a0b81dfb..2b1f48849 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -751,6 +751,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, return end +function _vecjacobian!(dλ, y, λ, p, t, S::SensitivityFunction, ::MooncakeVJP, dgrad, dy, W) + _dy, y_grad, p_grad = mooncake_run_ad(S.diffcache.paramjac_config, y, p, t, λ) + dy !== nothing && recursive_copyto!(dy, _dy) + dλ !== nothing && recursive_copyto!(dλ, y_grad) + dgrad !== nothing && recursive_copyto!(dgrad, p_grad) + return +end + function jacNoise!(λ, y, p, t, S::SensitivityFunction; dgrad = nothing, dλ = nothing, dy = nothing) _jacNoise!(λ, y, p, t, S, S.sensealg.autojacvec, dgrad, dλ, dy) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 238d3f2ba..11379bd7a 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -428,6 +428,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) end paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf) pJ = nothing + elseif sensealg.autojacvec isa MooncakeVJP + pf = get_pf(sensealg.autojacvec, prob, f) + paramjac_config = get_paramjac_config( + MooncakeLoaded(), sensealg.autojacvec, pf, p, f, y, tspan[2]) + pJ = nothing elseif isautojacvec # Zygote paramjac_config = nothing pf = nothing @@ -500,6 +505,9 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) + elseif sensealg.autojacvec isa MooncakeVJP + _, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ) + out .= p_grad else error("autojacvec choice $(sensealg.autojacvec) is not supported by GaussAdjoint") end diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index fea456345..27022f1da 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -235,6 +235,11 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing) end paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf) pJ = nothing + elseif sensealg.autojacvec isa MooncakeVJP + pf = get_pf(sensealg.autojacvec, prob, f) + paramjac_config = get_paramjac_config( + MooncakeLoaded(), sensealg.autojacvec, pf, p, f, y, tspan[2]) + pJ = nothing elseif isautojacvec # Zygote paramjac_config = nothing pf = nothing @@ -288,6 +293,9 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) else out[:] .= vec(tmp[1]) end + elseif sensealg.autojacvec isa MooncakeVJP + _, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ) + out .= p_grad elseif sensealg.autojacvec isa EnzymeVJP tmp3, tmp4, tmp6 = paramjac_config tmp4 .= λ diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 133267f7a..186675454 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1226,6 +1226,23 @@ struct ReverseDiffVJP{compile} <: VJPChoice ReverseDiffVJP(compile = false) = new{compile}() end +""" +```julia +MooncakeVJP <: VJPChoice +``` + +Uses Mooncake.jl to compute the vector-Jacobian products. + +Does not support GPUs (CuArrays). + +## Constructor + +```julia +MooncakeVJP() +``` +""" +struct MooncakeVJP <: VJPChoice end + @inline convert_tspan(::ForwardDiffSensitivity{CS, CTS}) where {CS, CTS} = CTS @inline convert_tspan(::Any) = nothing @inline function alg_autodiff(alg::AbstractSensitivityAlgorithm{ diff --git a/test/adjoint.jl b/test/adjoint.jl index 4e6adc168..4df7dad9c 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -135,6 +135,14 @@ _, easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = GaussAdjoint()) +_, easy_res15 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP())) +_, easy_res16 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP())) _, easy_res142 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, @@ -158,6 +166,10 @@ _, easy_res146 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discret sensealg = GaussAdjoint(checkpointing = true, autojacvec = false), checkpoints = sol.t[1:500:end]) +_, easy_res147 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP())) adj_prob = ODEAdjointProblem(sol, QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.ReverseDiffVJP()), @@ -189,11 +201,14 @@ res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-12) @test isapprox(res, easy_res12, rtol = 1e-9) @test isapprox(res, easy_res13, rtol = 1e-9) @test isapprox(res, easy_res14, rtol = 1e-9) +@test isapprox(res, easy_res15, rtol = 1e-9) +@test isapprox(res, easy_res16, rtol = 1e-9) @test isapprox(res, easy_res142, rtol = 1e-9) @test isapprox(res, easy_res143, rtol = 1e-9) @test isapprox(res, easy_res144, rtol = 1e-9) @test isapprox(res, easy_res145, rtol = 1e-9) @test isapprox(res, easy_res146, rtol = 1e-9) +@test isapprox(res, easy_res147, rtol = 1e-9) println("OOP adjoint sensitivities ") @@ -203,14 +218,11 @@ _, easy_res = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, _, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, - sensealg = QuadratureAdjoint(abstol = 1e-14, - reltol = 1e-14)) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14)) _, easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, - sensealg = QuadratureAdjoint(autojacvec = false, - abstol = 1e-14, - reltol = 1e-14)) + sensealg = QuadratureAdjoint(autojacvec = false, abstol = 1e-14, reltol = 1e-14)) _, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, @@ -224,8 +236,7 @@ _, easy_res3 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, @test easy_res32 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, - sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa - AbstractArray + sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa AbstractArray _, easy_res4 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, @@ -233,8 +244,7 @@ _, easy_res4 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, @test easy_res42 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, - sensealg = BacksolveAdjoint(autojacvec = false))[1] isa - AbstractArray + sensealg = BacksolveAdjoint(autojacvec = false))[1] isa AbstractArray _, easy_res5 = adjoint_sensitivities(soloop, Kvaerno5(nlsolve = NLAnderson(), smooth_est = false), t = t, dgdu_discrete = dg, abstol = 1e-12, @@ -248,8 +258,7 @@ _, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discre _, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, - sensealg = InterpolatingAdjoint(checkpointing = true, - autojacvec = false), + sensealg = InterpolatingAdjoint(checkpointing = true, autojacvec = false), checkpoints = soloop_nodense.t[1:5:end]) _, easy_res8 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discrete = dg, @@ -289,6 +298,39 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc reltol = 1e-14, sensealg = GaussAdjoint(checkpointing = true), checkpoints = soloop_nodense.t[1:5:end]) + +_, easy_res2_mc_quad = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint( + abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP())) +_, easy_res2_mc_interp = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP())) +_, easy_res2_mc_back = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP())) +_, easy_res6_mc_quad = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, + dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = QuadratureAdjoint( + abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP())) +_, easy_res6_mc_interp = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, + dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = InterpolatingAdjoint(checkpointing = true, + autojacvec = SciMLSensitivity.MooncakeVJP()), + checkpoints = soloop_nodense.t[1:5:end]) +_, easy_res6_mc_back = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, + dgdu_discrete = dg, + abstol = 1e-14, + reltol = 1e-14, + sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP())) + @test isapprox(res, easy_res, rtol = 1e-10) @test isapprox(res, easy_res2, rtol = 1e-10) @test isapprox(res, easy_res22, rtol = 1e-10) @@ -309,6 +351,12 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc @test isapprox(res, easy_res12, rtol = 1e-9) @test isapprox(res, easy_res122, rtol = 1e-9) @test isapprox(res, easy_res123, rtol = 1e-4) +@test isapprox(res, easy_res2_mc_quad, rtol = 1e-9) +@test isapprox(res, easy_res2_mc_interp, rtol = 1e-9) +@test isapprox(res, easy_res2_mc_back, rtol = 1e-9) +@test isapprox(res, easy_res6_mc_quad, rtol = 1e-4) +@test isapprox(res, easy_res6_mc_interp, rtol = 1e-9) +@test isapprox(res, easy_res6_mc_back, rtol = 1e-9) println("Calculate adjoint sensitivities ") diff --git a/test/runtests.jl b/test/runtests.jl index 6fa3db51b..3e9af1197 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SciMLSensitivity, SafeTestsets using Test, Pkg +import Mooncake const GROUP = get(ENV, "GROUP", "All")