diff --git a/Project.toml b/Project.toml index b0763846..3219992c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.3.1" +version = "0.3.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index bfce495b..e8bebf9d 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,23 +1,18 @@ -AD_distributionsad = if VERSION >= v"1.10" - Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment - :Zygote => AutoZygote(), - ) -end +AD_distributionsad = Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment + :Zygote => AutoZygote(), +) if @isdefined(Tapir) AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) end +if @isdefined(Enzyme) + AD_distributionsad[:Enzyme] = AutoEnzyme() +end + @testset "inference RepGradELBO DistributionsAD" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield), diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index d0e7b6d4..5ce92809 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -1,24 +1,18 @@ -AD_locationscale = if VERSION >= v"1.10" - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), - :Tapir => AutoTapir(; safe_mode=false), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - ) -end +AD_locationscale = Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), +) if @isdefined(Tapir) AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false) end +if @isdefined(Enzyme) + AD_locationscale[:Enzyme] = AutoEnzyme() +end + @testset "inference RepGradELBO VILocationScale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index ff37b82a..0fbe5ab7 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -1,24 +1,18 @@ -AD_locationscale_bijectors = if VERSION >= v"1.10" - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), - :Tapir => AutoTapir(; safe_mode=false), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - ) -end +AD_locationscale_bijectors = Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), +) if @isdefined(Tapir) AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) end +if @isdefined(Enzyme) + AD_locationscale_bijectors[:Enzyme] = AutoEnzyme() +end + @testset "inference RepGradELBO VILocationScale Bijectors" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 791bcbb3..e8f4da4e 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -5,12 +5,16 @@ const interface_ad_backends = Dict( :ForwardDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), ) + if @isdefined(Tapir) interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false) end +if @isdefined(Enzyme) + interface_ad_backends[:Enzyme] = AutoEnzyme() +end + @testset "ad" begin @testset "$(adname)" for (adname, adtype) in interface_ad_backends D = 10 diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 00eb2d37..baf1499a 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -35,14 +35,14 @@ end @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats ad_backends = [ - ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(), - ADTypes.AutoZygote(), - ADTypes.AutoEnzyme(), + ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() ] if @isdefined(Tapir) push!(ad_backends, AutoTapir(; safe_mode=false)) end + if @isdefined(Enzyme) + push!(ad_backends, AutoEnzyme()) + end @testset for ad in ad_backends q_true = MeanFieldGaussian( diff --git a/test/runtests.jl b/test/runtests.jl index c29305e0..43958e8e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,11 +21,12 @@ using DistributionsAD using LogDensityProblems using Optimisers using ADTypes -using ForwardDiff, ReverseDiff, Zygote, Enzyme +using ForwardDiff, ReverseDiff, Zygote if VERSION >= v"1.10" Pkg.add("Tapir") using Tapir + using Enzyme end using AdvancedVI