From be5d7b255ca06672ca5c2966c5c984df11a0a53b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 20 Sep 2024 10:20:44 +0900 Subject: [PATCH] Revert "Revert "Adapt to pending Enzyme breaking change"" (#94) * Revert "Revert "Adapt to pending Enzyme breaking change (#92)" (#93)" This reverts commit 5e9b84bc38ba36f38295cb4319999f65e5dac347. * fix errors, bump Enzyme compat in test and run formatter * remove Enzyme compat bound in tests like Tapir --- Project.toml | 2 +- ext/AdvancedVIEnzymeExt.jl | 9 +++++---- test/Project.toml | 1 - 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 3219992c..6322bfa7 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ ChainRulesCore = "1.16" DiffResults = "1" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" -Enzyme = "0.12.32" +Enzyme = "0.13" FillArrays = "1.3" ForwardDiff = "0.10.36" Functors = "0.4" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 45b3c547..3b68d531 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -18,11 +18,13 @@ end function AdvancedVI.value_and_gradient!( ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult ) - Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x) + Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), + Enzyme.Const(f), + Enzyme.Active, + Enzyme.Duplicated(x, ∇x), ) DiffResults.value!(out, y) return out @@ -35,11 +37,10 @@ function AdvancedVI.value_and_gradient!( aux, out::DiffResults.MutableDiffResult, ) - Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, + Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x), diff --git a/test/Project.toml b/test/Project.toml index 018198d1..ca0fc384 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,6 @@ Bijectors = "0.13" DiffResults = "1.0" Distributions = "0.25.111" DistributionsAD = "0.6.45" -Enzyme = "0.12.32" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5"