diff --git a/Project.toml b/Project.toml index caab92930..39d6ad5e7 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DiffEqNoiseProcess = "5.19" Distributed = "1" Distributions = "0.25" EllipsisNotation = "1" -Enzyme = "0.12" +Enzyme = "0.12.22" FiniteDiff = "2" ForwardDiff = "0.10" FunctionProperties = "0.1" diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 22d4401ef..ebd1c1d9e 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -680,7 +680,8 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, ytmp = _tmp5 end - tmp1 .= 0 # should be removed for dλ + tmp1 .= 0 + #Enzyme.make_zero!(tmp1) # should be removed for dλ ytmp .= y #if dgrad !== nothing @@ -699,12 +700,15 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, # tmp3 = dy #else tmp3 .= 0 + #Enzyme.make_zero!(tmp3) #end vec(tmp4) .= vec(λ) isautojacvec = get_jacvec(sensealg) + Enzyme.make_zero!(_tmp6) + if inplace_sensitivity(S) if W === nothing Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..717396059 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -497,7 +497,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) tmp3, tmp4, tmp6 = paramjac_config vtmp4 = vec(tmp4) vtmp4 .= λ - out .= 0 + Enzyme.make_zero!(out) + Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index ea062c17a..4ec165f1f 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -293,7 +293,8 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) elseif sensealg.autojacvec isa EnzymeVJP tmp3, tmp4, tmp6 = paramjac_config tmp4 .= λ - out .= 0 + Enzyme.make_zero!(out) + Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),