From e228a702e9ab58f47fd14217fb59a2cb9a9e8212 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 9 Jun 2024 17:26:01 -0400 Subject: [PATCH 1/7] Safety make_zero! on repeated Enzyme calls with caches This should ensure that the caches are always zero'd for the derivatives. --- Project.toml | 2 +- src/derivative_wrappers.jl | 2 ++ src/gauss_adjoint.jl | 1 + src/quadrature_adjoint.jl | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index caab92930..55a88942f 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DiffEqNoiseProcess = "5.19" Distributed = "1" Distributions = "0.25" EllipsisNotation = "1" -Enzyme = "0.12" +Enzyme = "0.12.12" FiniteDiff = "2" ForwardDiff = "0.10" FunctionProperties = "0.1" diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 22d4401ef..bc334db71 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -705,6 +705,8 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, isautojacvec = get_jacvec(sensealg) + Enzyme.make_zero!(_tmp6) + if inplace_sensitivity(S) if W === nothing Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..f1d13b03d 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -498,6 +498,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) vtmp4 = vec(tmp4) vtmp4 .= λ out .= 0 + Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index ea062c17a..648560de4 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -294,6 +294,7 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) tmp3, tmp4, tmp6 = paramjac_config tmp4 .= λ out .= 0 + Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), From 77ca11f56a00e4bf75342cc030980a2675503122 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 10 Jun 2024 12:22:02 -0400 Subject: [PATCH 2/7] a few more --- src/derivative_wrappers.jl | 4 ++-- src/gauss_adjoint.jl | 2 +- src/quadrature_adjoint.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index bc334db71..6e3bdbfe8 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -680,7 +680,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, ytmp = _tmp5 end - tmp1 .= 0 # should be removed for dλ + Enzyme.make_zero!(tmp1) # should be removed for dλ ytmp .= y #if dgrad !== nothing @@ -698,7 +698,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, #if dy !== nothing # tmp3 = dy #else - tmp3 .= 0 + Enzyme.make_zero!(tmp3) #end vec(tmp4) .= vec(λ) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index f1d13b03d..717396059 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -497,7 +497,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) tmp3, tmp4, tmp6 = paramjac_config vtmp4 = vec(tmp4) vtmp4 .= λ - out .= 0 + Enzyme.make_zero!(out) Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 648560de4..4ec165f1f 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -293,7 +293,7 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) elseif sensealg.autojacvec isa EnzymeVJP tmp3, tmp4, tmp6 = paramjac_config tmp4 .= λ - out .= 0 + Enzyme.make_zero!(out) Enzyme.make_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, From 320cdfc699ea0722329eda3e5dae3aec2dcf0df0 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 5 Jul 2024 09:34:39 -0400 Subject: [PATCH 3/7] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 55a88942f..026492dca 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DiffEqNoiseProcess = "5.19" Distributed = "1" Distributions = "0.25" EllipsisNotation = "1" -Enzyme = "0.12.12" +Enzyme = "0.12.21" FiniteDiff = "2" ForwardDiff = "0.10" FunctionProperties = "0.1" From 7f34cb84c4e8496e738f59cf692de0cb51e4ca2f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 5 Jul 2024 14:36:26 -0400 Subject: [PATCH 4/7] revert failing make_zeros due to https://github.com/EnzymeAD/Enzyme.jl/issues/1611 --- src/derivative_wrappers.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 6e3bdbfe8..ebd1c1d9e 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -680,7 +680,8 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, ytmp = _tmp5 end - Enzyme.make_zero!(tmp1) # should be removed for dλ + tmp1 .= 0 + #Enzyme.make_zero!(tmp1) # should be removed for dλ ytmp .= y #if dgrad !== nothing @@ -698,7 +699,8 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, #if dy !== nothing # tmp3 = dy #else - Enzyme.make_zero!(tmp3) + tmp3 .= 0 + #Enzyme.make_zero!(tmp3) #end vec(tmp4) .= vec(λ) From ded63ae3652cb0920bdcd3a9479d070ce14b6008 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 9 Jul 2024 14:11:09 +0200 Subject: [PATCH 5/7] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 026492dca..39d6ad5e7 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DiffEqNoiseProcess = "5.19" Distributed = "1" Distributions = "0.25" EllipsisNotation = "1" -Enzyme = "0.12.21" +Enzyme = "0.12.22" FiniteDiff = "2" ForwardDiff = "0.10" FunctionProperties = "0.1" From 7bc9e2cf1121702356684bc92cb54689a5d3b827 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 14 Jul 2024 18:47:03 +0200 Subject: [PATCH 6/7] Only duplicate and make_zero! if enzyme func is non-constant --- src/derivative_wrappers.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index ebd1c1d9e..8f9d82dff 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -707,11 +707,16 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, isautojacvec = get_jacvec(sensealg) - Enzyme.make_zero!(_tmp6) + if Core.Compiler.isconstType(_tmp6) + Enzyme.make_zero!(_tmp6) + _f = Enzyme.Duplicated(S.diffcache.pf, _tmp6) + else + _f = S.diffcache.pf + end if inplace_sensitivity(S) if W === nothing - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), + Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, From de62cab5e3cf067df512f961793789e219e6ba1c Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 21 Jul 2024 17:13:34 -0400 Subject: [PATCH 7/7] Revert "Only duplicate and make_zero! if enzyme func is non-constant" This reverts commit 7bc9e2cf1121702356684bc92cb54689a5d3b827. --- src/derivative_wrappers.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 8f9d82dff..ebd1c1d9e 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -707,16 +707,11 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, isautojacvec = get_jacvec(sensealg) - if Core.Compiler.isconstType(_tmp6) - Enzyme.make_zero!(_tmp6) - _f = Enzyme.Duplicated(S.diffcache.pf, _tmp6) - else - _f = S.diffcache.pf - end + Enzyme.make_zero!(_tmp6) if inplace_sensitivity(S) if W === nothing - Enzyme.autodiff(Enzyme.Reverse, _f, + Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup,