From 6e09d4dff68599e490e1ced201f30a05e6e248b8 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Oct 2024 13:47:13 -0700 Subject: [PATCH 1/3] bump Enzyme version, update Enzyme interface --- Project.toml | 2 +- ext/AdvancedVIEnzymeExt.jl | 21 ++++++++++++--------- test/runtests.jl | 5 +---- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 5cddb1d73..295863d62 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ DiffResults = "1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" DocStringExtensions = "0.8, 0.9" -Enzyme = "0.11" +Enzyme = "0.13" LinearAlgebra = "1.6" ForwardDiff = "0.10.3" Flux = "0.14" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 025c3901f..529483ab3 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -23,19 +23,22 @@ function AdvancedVI.grad!( out::DiffResults.MutableDiffResult, args... ) - f(θ) = - if (q isa Distributions.Distribution) - -vo(alg, AdvancedVI.update(q, θ), model, args...) - else - -vo(alg, q(θ), model, args...) - end - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 + f(θ) = if (q isa Distributions.Distribution) + -vo(alg, AdvancedVI.update(q, θ), model, args...) + else + -vo(alg, q(θ), model, args...) + end + y = f(θ) DiffResults.value!(out, y) dy = DiffResults.gradient(out) fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy)) + Enzyme.autodiff( + Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), + Enzyme.Const(f), + Enzyme.Active, + Enzyme.Duplicated(θ, dy) + ) return out end diff --git a/test/runtests.jl b/test/runtests.jl index 71a611e03..e547af913 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,8 +6,6 @@ using ReverseDiff: ReverseDiff using Tracker: Tracker using Zygote: Zygote using Enzyme: Enzyme -Enzyme.API.runtimeActivity!(true); -Enzyme.API.typeWarning!(false); using AdvancedVI @@ -22,7 +20,7 @@ include("optimisers.jl") AutoReverseDiff(), AutoTracker(), AutoZygote(), - # AutoEnzyme() # results in incorrect result + AutoEnzyme() ] target = MvNormal(ones(2)) logπ(z) = logpdf(target, z) @@ -42,5 +40,4 @@ include("optimisers.jl") xs = rand(target, 10) @test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 - end From 5cafd5fd5294115d38afa4d05d9d22bb1a883648 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Oct 2024 13:59:12 -0700 Subject: [PATCH 2/3] drop testing on Julia 1.6 for v0.2 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9731f20c2..9f3791724 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: - os: macOS-latest arch: x86 include: - - version: '1.6' + - version: '1.10' os: ubuntu-latest arch: x64 - os: ubuntu-latest From c898d8e0606aae877c819102a36c70b24930dde3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Oct 2024 14:25:04 -0700 Subject: [PATCH 3/3] fix disable testing on Enzyme --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index e547af913..63f40c80a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,7 @@ include("optimisers.jl") AutoReverseDiff(), AutoTracker(), AutoZygote(), - AutoEnzyme() + # AutoEnzyme() ] target = MvNormal(ones(2)) logπ(z) = logpdf(target, z)