From 0c56becd34294d04e7c6c2bbae124bf7ac5956a5 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 9 Jun 2024 17:26:01 -0400 Subject: [PATCH] Safety make_zero! on repeated Enzyme calls with caches This should ensure that the caches are always zero'd for the derivatives. --- Project.toml | 2 +- src/derivative_wrappers.jl | 2 ++ src/gauss_adjoint.jl | 1 + src/quadrature_adjoint.jl | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 123e97c3a..2c41e95b0 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DiffEqNoiseProcess = "5.19" Distributed = "1" Distributions = "0.25" EllipsisNotation = "1" -Enzyme = "0.12" +Enzyme = "0.12.12" FiniteDiff = "2" ForwardDiff = "0.10" FunctionProperties = "0.1" diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 5d5f3464f..55d84dc6d 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -705,6 +705,8 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, isautojacvec = get_jacvec(sensealg) + Enzyme.make_zero!(_tmp6) + if inplace_sensitivity(S) if W === nothing Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..f1d13b03d 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -498,6 +498,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) vtmp4 = vec(tmp4) vtmp4 .= λ out .= 0 + Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index ea062c17a..648560de4 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -294,6 +294,7 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) tmp3, tmp4, tmp6 = paramjac_config tmp4 .= λ out .= 0 + Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),