diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml new file mode 100644 index 00000000..cf5dcdf3 --- /dev/null +++ b/.buildkite/pipeline.yml @@ -0,0 +1,18 @@ +steps: + - label: "CUDA with julia {{matrix.julia}}" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + agents: + queue: "juliagpu" + cuda: "*" + timeout_in_minutes: 60 + env: + GROUP: "GPU" + ADVANCEDVI_TEST_CUDA: "true" + matrix: + setup: + julia: + - "1.10" diff --git a/Project.toml b/Project.toml index 6322bfa7..572ea144 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -24,50 +25,47 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] AdvancedVIBijectorsExt = "Bijectors" AdvancedVIEnzymeExt = "Enzyme" -AdvancedVIForwardDiffExt = "ForwardDiff" -AdvancedVIReverseDiffExt = "ReverseDiff" -AdvancedVITapirExt = "Tapir" -AdvancedVIZygoteExt = "Zygote" [compat] -ADTypes = "0.1, 0.2, 1" +ADTypes = "1" Accessors = "0.1" Bijectors = "0.13" ChainRulesCore = "1.16" DiffResults = "1" +DifferentiationInterface = "0.6" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" Enzyme = "0.13" FillArrays = "1.3" -ForwardDiff = 
"0.10.36" +ForwardDiff = "0.10" Functors = "0.4" LinearAlgebra = "1" LogDensityProblems = "2" +Mooncake = "0.4" Optimisers = "0.2.16, 0.3" ProgressMeter = "1.6" Random = "1" Requires = "1.0" -ReverseDiff = "1.15.1" +ReverseDiff = "1" SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -Tapir = "0.2" -Zygote = "0.6.63" +Zygote = "0.6" julia = "1.7" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/Project.toml b/docs/Project.toml index f42b21bc..8dc25a3b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -14,7 +14,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] -ADTypes = "0.1.6" +ADTypes = "1" AdvancedVI = "0.3" Bijectors = "0.13.6" Distributions = "0.25" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 3b68d531..a4119c3b 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -1,4 +1,3 @@ - module AdvancedVIEnzymeExt if isdefined(Base, :get_extension) @@ -15,21 +14,6 @@ function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, pa return restructure(params)::typeof(restructure.model) end -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult -) - ∇x = DiffResults.gradient(out) - fill!(∇x, zero(eltype(∇x))) - _, y = Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), - Enzyme.Const(f), - Enzyme.Active, - Enzyme.Duplicated(x, ∇x), - ) - DiffResults.value!(out, y) - return out -end - function 
AdvancedVI.value_and_gradient!( ::ADTypes.AutoEnzyme, f, diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index 6904fa7a..e69de29b 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -1,42 +0,0 @@ - -module AdvancedVIForwardDiffExt - -if isdefined(Base, :get_extension) - using ForwardDiff - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults -else - using ..ForwardDiff - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults -end - -getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, - f, - x::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, -) - chunk_size = getchunksize(ad) - config = if isnothing(chunk_size) - ForwardDiff.GradientConfig(f, x) - else - ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size)) - end - ForwardDiff.gradient!(out, f, x, config) - return out -end - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, - f, - x::AbstractVector, - aux, - out::DiffResults.MutableDiffResult, -) - return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) -end - -end diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl new file mode 100644 index 00000000..e69de29b diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 9cde91a1..e69de29b 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -1,36 +0,0 @@ - -module AdvancedVIReverseDiffExt - -if isdefined(Base, :get_extension) - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults - using ReverseDiff -else - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults - using ..ReverseDiff -end - -# ReverseDiff without compiled tape -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, - f, - x::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, -) - tp = 
ReverseDiff.GradientTape(f, x) - ReverseDiff.gradient!(out, tp, x) - return out -end - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, - f, - x::AbstractVector{<:Real}, - aux, - out::DiffResults.MutableDiffResult, -) - return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) -end - -end diff --git a/ext/AdvancedVITapirExt.jl b/ext/AdvancedVITapirExt.jl deleted file mode 100644 index 459ef7da..00000000 --- a/ext/AdvancedVITapirExt.jl +++ /dev/null @@ -1,37 +0,0 @@ -module AdvancedVITapirExt - -if isdefined(Base, :get_extension) - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults - using Tapir -else - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults - using ..Tapir -end - -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoTapir, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult -) - rule = Tapir.build_rrule(f, x) - y, g = Tapir.value_and_gradient!!(rule, f, x) - DiffResults.value!(out, y) - DiffResults.gradient!(out, last(g)) - return out -end - -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoTapir, - f, - x::AbstractVector{<:Real}, - aux, - out::DiffResults.MutableDiffResult, -) - rule = Tapir.build_rrule(f, x, aux) - y, g = Tapir.value_and_gradient!!(rule, f, x, aux) - DiffResults.value!(out, y) - DiffResults.gradient!(out, g[2]) - return out -end - -end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index 2cdd8392..e69de29b 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -1,36 +0,0 @@ - -module AdvancedVIZygoteExt - -if isdefined(Base, :get_extension) - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults - using ChainRulesCore - using Zygote -else - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults - using ..ChainRulesCore - using ..Zygote -end - -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult -) - y, back = Zygote.pullback(f, 
x) - ∇x = back(one(y)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, only(∇x)) - return out -end - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, - f, - x::AbstractVector{<:Real}, - aux, - out::DiffResults.MutableDiffResult, -) - return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) -end - -end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 4c3c39cc..aebe765e 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -16,16 +16,17 @@ using LinearAlgebra using LogDensityProblems -using ADTypes, DiffResults +using ADTypes +using DiffResults +using DifferentiationInterface using ChainRulesCore using FillArrays using StatsBase -# derivatives +# Derivatives """ - value_and_gradient!(ad, f, x, out) value_and_gradient!(ad, f, x, aux, out) Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. @@ -38,7 +39,14 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif - `aux`: Auxiliary input passed to `f`. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ -function value_and_gradient! 
end +function value_and_gradient!( + ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult +) + grad_buf = DiffResults.gradient(out) + y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux)) + DiffResults.value!(out, y) + return out +end """ restructure_ad_forward(adtype, restructure, params) @@ -131,7 +139,7 @@ function estimate_objective end export estimate_objective """ - estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state) + estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` @@ -141,7 +149,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. - `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `λ`: Variational parameters to evaluate the gradient on. +- `params`: Variational parameters to evaluate the gradient on. - `restructure`: Function that reconstructs the variational approximation from `λ`. - `obj_state`: Previous state of the objective. diff --git a/src/optimize.jl b/src/optimize.jl index 9a748907..8ef9db76 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -42,7 +42,7 @@ The arguments are as follows: - `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. - `gradient`: The estimated (possibly stochastic) gradient. -`cb` can return a `NamedTuple` containing some additional information computed within `cb`. +`callback` can return a `NamedTuple` containing some additional information computed within `callback`. 
This will be appended to the statistic of the current corresponding iteration. Otherwise, just return `nothing`. diff --git a/test/Project.toml b/test/Project.toml index ca0fc384..bbf9c4c6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,9 +2,9 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -26,7 +26,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1, 1" Bijectors = "0.13" -DiffResults = "1.0" +DiffResults = "1" +DifferentiationInterface = "0.6" Distributions = "0.25.111" DistributionsAD = "0.6.45" FillArrays = "1.6.1" @@ -41,6 +42,7 @@ ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" Statistics = "1" +StatsBase = "0.34" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 94da09bc..4086a205 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -5,12 +5,14 @@ AD_distributionsad = Dict( :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_distributionsad[:Enzyme] = AutoEnzyme() + AD_distributionsad[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end @testset "inference RepGradELBO DistributionsAD" begin diff --git 
a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 9e254b6a..1ca31885 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -5,15 +5,17 @@ AD_locationscale = Dict( :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_locationscale[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_locationscale[:Enzyme] = AutoEnzyme() + AD_locationscale[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end -@testset "inference RepGradELBO VILocationScale" begin +@testset "inference RepGradELBO VILocationScale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 731326f3..167fe389 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -5,12 +5,14 @@ AD_locationscale_bijectors = Dict( :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_locationscale_bijectors[:Enzyme] = AutoEnzyme() + AD_locationscale_bijectors[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end @testset "inference RepGradELBO VILocationScale Bijectors" begin diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index 700dda6d..1de7af1d 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -1,19 +1,19 @@ 
-AD_distributionsad = Dict( +AD_scoregradelbo_distributionsad = Dict( :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=nothing) end #if @isdefined(Enzyme) -# AD_distributionsad[:Enzyme] = AutoEnzyme() +# AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme() #end -@testset "inference RepGradELBO DistributionsAD" begin +@testset "inference ScoreGradELBO DistributionsAD" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield), @@ -23,7 +23,7 @@ end :ScoreGradELBOStickingTheLanding => ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_distributionsad + (adbackname, adtype) in AD_scoregradelbo_distributionsad seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index ef49713b..f0073d7c 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -1,16 +1,18 @@ -AD_locationscale = Dict( +AD_scoregradelbo_locationscale = Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_locationscale[:Enzyme] = AutoEnzyme() + AD_scoregradelbo_locationscale[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end @testset "inference ScoreGradELBO VILocationScale" begin diff --git 
a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index 088130aa..bee8234a 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -1,16 +1,16 @@ -AD_locationscale_bijectors = Dict( +AD_scoregradelbo_locationscale_bijectors = Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), #:Zygote => AutoZygote(), ) #if @isdefined(Tapir) -# AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) +# AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) #end if @isdefined(Enzyme) - AD_locationscale_bijectors[:Enzyme] = AutoEnzyme() + AD_scoregradelbo_locationscale_bijectors[:Enzyme] = AutoEnzyme() end @testset "inference ScoreGradELBO VILocationScale Bijectors" begin @@ -24,7 +24,7 @@ end :ScoreGradELBOStickingTheLanding => ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_locationscale_bijectors + (adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index e8f4da4e..713a0f56 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -17,19 +17,6 @@ end @testset "ad" begin @testset "$(adname)" for (adname, adtype) in interface_ad_backends - D = 10 - A = randn(D, D) - λ = randn(D) - grad_buf = DiffResults.GradientResult(λ) - f(λ′) = λ′' * A * λ′ / 2 - AdvancedVI.value_and_gradient!(adtype, f, λ, grad_buf) - ∇ = DiffResults.gradient(grad_buf) - f = DiffResults.value(grad_buf) - @test ∇ ≈ (A + A') * λ / 2 - @test f ≈ λ' * A * λ / 2 - end - - @testset "$(adname) with auxiliary input" for (adname, adtype) in interface_ad_backends D = 10 A = randn(D, D) λ = randn(D) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index baf1499a..afd6249e 100644 --- a/test/interface/repgradelbo.jl +++ 
b/test/interface/repgradelbo.jl @@ -37,14 +37,19 @@ end ad_backends = [ ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() ] - if @isdefined(Tapir) - push!(ad_backends, AutoTapir(; safe_mode=false)) + if @isdefined(Mooncake) + push!(ad_backends, AutoMooncake(; config=nothing)) end if @isdefined(Enzyme) - push!(ad_backends, AutoEnzyme()) + push!( + ad_backends, + AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ), + ) end - @testset for ad in ad_backends + @testset for adtype in ad_backends q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) @@ -52,9 +57,11 @@ end obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) + aux = ( + rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype + ) AdvancedVI.value_and_gradient!( - ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol = 1e-5 diff --git a/test/interface/scoregradelbo.jl b/test/interface/scoregradelbo.jl index a800f744..8a6ebb14 100644 --- a/test/interface/scoregradelbo.jl +++ b/test/interface/scoregradelbo.jl @@ -26,32 +26,3 @@ using Test @test elbo ≈ elbo_ref rtol = 0.2 end end - -@testset "interface ScoreGradELBO STL variance reduction" begin - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - modelstats = normal_meanfield(rng, Float64) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - @testset for ad in [ - ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() - ] - q_true = MeanFieldGaussian( - Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) - ) - params, re = Optimisers.destructure(q_true) - 
obj = ScoreGradELBO( - 1000; entropy=StickingTheLandingEntropy(), baseline_history=[0.0] - ) - out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) - AdvancedVI.value_and_gradient!( - ad, AdvancedVI.estimate_scoregradelbo_ad_forward, params, aux, out - ) - value = DiffResults.value(out) - grad = DiffResults.gradient(out) - @test norm(grad) ≈ 0 atol = 10 # high tolerance required. - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 85bec3a7..7c0e3129 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,9 +4,12 @@ using Test: @testset, @test using Base.Iterators using Bijectors +using DiffResults using Distributions using FillArrays using LinearAlgebra +using LogDensityProblems +using Optimisers using PDMats using Pkg using Random, StableRNGs @@ -18,14 +21,13 @@ using Functors using DistributionsAD @functor TuringDiagMvNormal -using LogDensityProblems -using Optimisers using ADTypes using ForwardDiff, ReverseDiff, Zygote if VERSION >= v"1.10" - Pkg.add("Tapir") - using Tapir + Pkg.add("Mooncake") + Pkg.add("Enzyme") + using Mooncake using Enzyme end