diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b19874b3..224f81d4 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,6 +20,7 @@ jobs: matrix: version: - '1.7' + - '1.10' os: - ubuntu-latest - macOS-latest diff --git a/Project.toml b/Project.toml index d1b0ce8e..fff721f2 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ ChainRulesCore = "1.16" DiffResults = "1" Distributions = "0.25.87" DocStringExtensions = "0.8, 0.9" -Enzyme = "0.12" +Enzyme = "0.12.32" FillArrays = "1.3" ForwardDiff = "0.10.36" Functors = "0.4" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 5bd2aa0a..45b3c547 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -11,13 +11,39 @@ else using ..AdvancedVI: ADTypes, DiffResults end +function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params) + return restructure(params)::typeof(restructure.model) +end + +function AdvancedVI.value_and_gradient!( + ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + Enzyme.API.runtimeActivity!(true) + ∇x = DiffResults.gradient(out) + fill!(∇x, zero(eltype(∇x))) + _, y = Enzyme.autodiff( + Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x) + ) + DiffResults.value!(out, y) + return out +end + function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} - ∇θ = DiffResults.gradient(out) - fill!(∇θ, zero(T)) + ::ADTypes.AutoEnzyme, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult, +) + Enzyme.API.runtimeActivity!(true) + ∇x = DiffResults.gradient(out) + fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(θ, ∇θ) + Enzyme.ReverseWithPrimal, + Enzyme.Const(f), + Enzyme.Active, + Enzyme.Duplicated(x, ∇x), + Enzyme.Const(aux), ) 
DiffResults.value!(out, y) return out diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index dfd682d5..8ac1b645 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -41,18 +41,17 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif function value_and_gradient! end """ - stop_gradient(x) + restructure_ad_forward(adtype, restructure, params) -Stop the gradient from propagating to `x` if the selected ad backend supports it. -Otherwise, it is equivalent to `identity`. +Apply `restructure` to `params`. +This is an indirection for handling the type stability of `restructure`, as some AD backends require strict type stability in the AD path. # Arguments -- `x`: Input - -# Returns -- `x`: Same value as the input. +- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `restructure`: Callable for restructuring the variational distribution from `params`. +- `params`: Variational Parameters. """ -function stop_gradient end +restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) # Update for gradient descent step """ diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 552e6b93..1aab2e71 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -38,14 +38,14 @@ Functors.@functor MvLocationScale (location, scale) # is very inefficient.
# begin struct RestructureMeanField{S<:Diagonal,D,L} - q::MvLocationScale{S,D,L} + model::MvLocationScale{S,D,L} end function (re::RestructureMeanField)(flat::AbstractVector) n_dims = div(length(flat), 2) location = first(flat, n_dims) scale = Diagonal(last(flat, n_dims)) - return MvLocationScale(location, scale, re.q.dist, re.q.scale_eps) + return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps) end function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L} diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index b7af17b1..e6f04ae8 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -93,8 +93,8 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samp end function estimate_repgradelbo_ad_forward(params′, aux) - @unpack rng, obj, problem, restructure, q_stop = aux - q = restructure(params′) + @unpack rng, obj, problem, adtype, restructure, q_stop = aux + q = restructure_ad_forward(adtype, restructure, params′) samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) energy = estimate_energy_with_samples(problem, samples) elbo = energy + entropy @@ -112,7 +112,14 @@ function estimate_gradient!( state, ) q_stop = restructure(params) - aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) + aux = ( + rng=rng, + adtype=adtype, + obj=obj, + problem=prob, + restructure=restructure, + q_stop=q_stop, + ) value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) diff --git a/test/Project.toml b/test/Project.toml index d7212699..251869e7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -26,9 +26,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1, 1" Bijectors = "0.13" +DiffResults = "1.0" Distributions = "0.25.100" DistributionsAD = "0.6.45" -Enzyme = "0.12" +Enzyme = 
"0.12.32" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 0ca2223f..3a458e38 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,4 +1,19 @@ +AD_distributionsad = if VERSION >= v"1.10" + Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment + :Zygote => AutoZygote(), + ) +end + @testset "inference RepGradELBO DistributionsAD" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], @@ -9,12 +24,7 @@ :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - #:Enzyme => AutoEnzyme(), - ) + (adbackname, adtype) in AD_distributionsad seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -31,8 +41,8 @@ # where ρ = 1 - ημ, μ is the strong convexity constant. 
contraction_rate = 1 - η * strong_convexity - μ0 = Zeros(realtype, n_dims) - L0 = Diagonal(Ones(realtype, n_dims)) + μ0 = zeros(realtype, n_dims) + L0 = Diagonal(ones(realtype, n_dims)) q0 = TuringDiagMvNormal(μ0, diag(L0)) @testset "convergence" begin diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 56ddc7b5..b0007706 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -1,4 +1,19 @@ +AD_locationscale = if VERSION >= v"1.10" + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + ) +end + @testset "inference RepGradELBO VILocationScale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], @@ -10,12 +25,7 @@ :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - #:Enzyme => AutoEnzyme(), - ) + (adbackname, adtype) in AD_locationscale seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 245c1544..79c81c52 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -1,4 +1,19 @@ +AD_locationscale_bijectors = if VERSION >= v"1.10" + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + ) +end + @testset "inference RepGradELBO VILocationScale 
Bijectors" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], @@ -10,12 +25,7 @@ :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), - #:Enzyme => AutoEnzyme(), - ) + (adbackname, adtype) in AD_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 14d6ff69..380c2b9b 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -19,4 +19,23 @@ using Test @test ∇ ≈ (A + A') * λ / 2 @test f ≈ λ' * A * λ / 2 end + + @testset "$(adname) with auxiliary input" for (adname, adsymbol) in Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) + D = 10 + A = randn(D, D) + λ = randn(D) + b = randn(D) + grad_buf = DiffResults.GradientResult(λ) + f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) + AdvancedVI.value_and_gradient!(adsymbol, f, λ, (b=b,), grad_buf) + ∇ = DiffResults.gradient(grad_buf) + f = DiffResults.value(grad_buf) + @test ∇ ≈ (A + A') * λ / 2 + b + @test f ≈ λ' * A * λ / 2 + dot(b, λ) + end end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index a6c45a75..5fec46ff 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -35,7 +35,10 @@ end @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats @testset for ad in [ - ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() + ADTypes.AutoForwardDiff(), + ADTypes.AutoReverseDiff(), + ADTypes.AutoZygote(), + ADTypes.AutoEnzyme(), ] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) @@ -44,7 +47,7 @@ end obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) out = 
DiffResults.DiffResult(zero(eltype(params)), similar(params)) - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) AdvancedVI.value_and_gradient!( ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) diff --git a/test/models/normal.jl b/test/models/normal.jl index 3aa95f4c..9fc6ae38 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -35,9 +35,7 @@ function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) σ0 = realtype(0.3) μ = Fill(realtype(5), n_dims) - #randn(rng, realtype, n_dims) σ = Fill(σ0, n_dims) - #log.(exp.(randn(rng, realtype, n_dims)) .+ 1) model = TestNormal(μ, Diagonal(σ .^ 2)) diff --git a/test/runtests.jl b/test/runtests.jl index 31028167..80194a43 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,7 @@ using DistributionsAD using LogDensityProblems using Optimisers using ADTypes -using Enzyme, ForwardDiff, ReverseDiff, Zygote +using ForwardDiff, ReverseDiff, Zygote, Enzyme using AdvancedVI