diff --git a/Project.toml b/Project.toml index 3219992c..6322bfa7 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ ChainRulesCore = "1.16" DiffResults = "1" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" -Enzyme = "0.12.32" +Enzyme = "0.13" FillArrays = "1.3" ForwardDiff = "0.10.36" Functors = "0.4" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 45b3c547..7fa05e56 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -18,11 +18,10 @@ end function AdvancedVI.value_and_gradient!( ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult ) - Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x) + Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x) ) DiffResults.value!(out, y) return out @@ -35,11 +34,10 @@ function AdvancedVI.value_and_gradient!( aux, out::DiffResults.MutableDiffResult, ) - Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, + Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true) Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x),