Skip to content

Commit

Permalink
Safety make_zero! on repeated Enzyme calls with caches
Browse files Browse the repository at this point in the history
This should ensure that the caches are always zero'd for the derivatives.
  • Loading branch information
ChrisRackauckas committed Jun 28, 2024
1 parent 456ea4b commit 0c56bec
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ DiffEqNoiseProcess = "5.19"
Distributed = "1"
Distributions = "0.25"
EllipsisNotation = "1"
Enzyme = "0.12"
Enzyme = "0.12.12"
FiniteDiff = "2"
ForwardDiff = "0.10"
FunctionProperties = "0.1"
Expand Down
2 changes: 2 additions & 0 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,8 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,

isautojacvec = get_jacvec(sensealg)

Enzyme.make_zero!(_tmp6)

if inplace_sensitivity(S)
if W === nothing
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6),
Expand Down
1 change: 1 addition & 0 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
vtmp4 = vec(tmp4)
vtmp4 .= λ
out .= 0
Enzyme.make_zero!(tmp6)
Enzyme.autodiff(
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Expand Down
1 change: 1 addition & 0 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
tmp3, tmp4, tmp6 = paramjac_config
tmp4 .= λ
out .= 0
Enzyme.make_zero!(tmp6)
Enzyme.autodiff(
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Expand Down

0 comments on commit 0c56bec

Please sign in to comment.