diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 224f81d4b..dcd006c0b 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -19,8 +19,8 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.7'
-          - '1.10'
+          - 'lts'
+          - '1'
         os:
           - ubuntu-latest
          - macOS-latest
@@ -29,7 +29,7 @@
         arch:
          - x64
     steps:
       - uses: actions/checkout@v4
-      - uses: julia-actions/setup-julia@v1
+      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.version }}
          arch: ${{ matrix.arch }}
diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml
new file mode 100644
index 000000000..d3dbf1fc9
--- /dev/null
+++ b/.github/workflows/Enzyme.yml
@@ -0,0 +1,40 @@
+name: Enzyme
+on:
+  push:
+    branches:
+      - master
+    tags: ['*']
+  pull_request:
+  workflow_dispatch:
+concurrency:
+  # Skip intermediate builds: always.
+  # Cancel intermediate builds: only if it is a pull request build.
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
+jobs:
+  test:
+    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
+    env:
+      TEST_GROUP: Enzyme
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        version:
+          - 'lts'
+          - '1'
+        os:
+          - ubuntu-latest
+          - macOS-latest
+          - windows-latest
+        arch:
+          - x64
+    steps:
+      - uses: actions/checkout@v4
+      - uses: julia-actions/setup-julia@v2
+        with:
+          version: ${{ matrix.version }}
+          arch: ${{ matrix.arch }}
+      - uses: julia-actions/cache@v1
+      - uses: julia-actions/julia-buildpkg@v1
+      - uses: julia-actions/julia-runtest@v1
diff --git a/Project.toml b/Project.toml
index 572ea1443..5f44dadda 100644
--- a/Project.toml
+++ b/Project.toml
@@ -18,7 +18,6 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
-SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

 [weakdeps]
@@ -36,7 +35,7 @@ AdvancedVIEnzymeExt = "Enzyme"
 [compat]
 ADTypes = "1"
 Accessors = "0.1"
-Bijectors = "0.13"
+Bijectors = "0.13, 0.14, 0.15"
 ChainRulesCore = "1.16"
 DiffResults = "1"
 DifferentiationInterface = "0.6"
@@ -45,29 +44,24 @@ DocStringExtensions = "0.8, 0.9"
 Enzyme = "0.13"
 FillArrays = "1.3"
 ForwardDiff = "0.10"
-Functors = "0.4"
+Functors = "0.4, 0.5"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
 Mooncake = "0.4"
-Optimisers = "0.2.16, 0.3"
+Optimisers = "0.2.16, 0.3, 0.4"
 ProgressMeter = "1.6"
 Random = "1"
 Requires = "1.0"
 ReverseDiff = "1"
-SimpleUnPack = "1.1.0"
 StatsBase = "0.32, 0.33, 0.34"
 Zygote = "0.6"
-julia = "1.7"
+julia = "1.10, 1.11.2"

 [extras]
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
 test = ["Pkg", "Test"]
diff --git a/README.md b/README.md
index bc1ef6bb7..15c561209 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,6 @@ a `LogDensityProblem` can be implemented as
 ```julia
 using LogDensityProblems
-using SimpleUnPack

 struct NormalLogNormal{MX,SX,MY,SY}
     μ_x::MX
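The new `Enzyme` workflow selects the Enzyme-specific tests through the `TEST_GROUP` environment variable, which `test/runtests.jl` reads via `get(ENV, "TEST_GROUP", "All")` (see the end of this diff). A minimal sketch of reproducing that CI job locally, assuming the repository root is the active project:

```julia
# Sketch: run only the Enzyme test group locally, mirroring Enzyme.yml.
# Assumes `julia --project=.` from the package root. Pkg.test spawns a
# subprocess that inherits this process's environment, so TEST_GROUP
# reaches test/runtests.jl.
using Pkg
ENV["TEST_GROUP"] = "Enzyme"
Pkg.test()
```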
diff --git a/bench/Project.toml b/bench/Project.toml
index e748ff059..78eecd48c 100644
--- a/bench/Project.toml
+++ b/bench/Project.toml
@@ -21,7 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 ADTypes = "1"
 BenchmarkTools = "1"
-Bijectors = "0.13"
+Bijectors = "0.13, 0.14, 0.15"
 Distributions = "0.25.111"
 DistributionsAD = "0.6"
 Enzyme = "0.13.7"
@@ -30,10 +30,10 @@ ForwardDiff = "0.10"
 InteractiveUtils = "1"
 LogDensityProblems = "2"
 Mooncake = "0.4.5"
-Optimisers = "0.3"
+Optimisers = "0.3, 0.4"
 Random = "1"
 ReverseDiff = "1"
 SimpleUnPack = "1"
 StableRNGs = "1"
 Zygote = "0.6"
-julia = "1.10"
+julia = "1.10, 1.11.2"
diff --git a/docs/Project.toml b/docs/Project.toml
index 38c051146..39ec7317c 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -11,22 +11,20 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
-SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

 [compat]
 ADTypes = "1"
 AdvancedVI = "0.3"
-Bijectors = "0.13.6"
+Bijectors = "0.13.6, 0.14, 0.15"
 Distributions = "0.25"
-Documenter = "0.26, 0.27"
+Documenter = "1"
 FillArrays = "1"
 ForwardDiff = "0.10"
 LogDensityProblems = "2.1.1"
-Optimisers = "0.3"
+Optimisers = "0.3, 0.4"
 Plots = "1"
 QuasiMonteCarlo = "0.3"
 ReverseDiff = "1"
-SimpleUnPack = "1"
 StatsFuns = "1"
-julia = "1.10"
+julia = "1.10, 1.11.2"
diff --git a/docs/make.jl b/docs/make.jl
index c70bf05fd..62d266aa4 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -20,6 +20,7 @@ makedocs(;
         "Variational Families" => "families.md",
         "Optimization" => "optimization.md",
     ],
+    warnonly=[:missing_docs],
 )

 deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md
index ff404a517..59d5dc962 100644
--- a/docs/src/elbo/repgradelbo.md
+++ b/docs/src/elbo/repgradelbo.md
@@ -129,7 +129,6 @@ using LinearAlgebra
 using LogDensityProblems
 using Plots
 using Random
-using SimpleUnPack

 using Optimisers
 using ADTypes, ForwardDiff
@@ -143,7 +142,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end

 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end

@@ -168,7 +167,7 @@ L = Diagonal(ones(d));
 q0 = AdvancedVI.MeanFieldGaussian(μ, L)

 function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
@@ -295,7 +294,7 @@ qmcrng = SobolSample(; R=OwenScramble(; base=2, pad=32))
 function Distributions.rand(
     rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
 ) where {L,D}
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     scale_diag = diag(scale)
     unif_samples = QuasiMonteCarlo.sample(num_samples, length(q), qmcrng)
@@ -337,7 +336,7 @@ savefig("advi_qmc_dist.svg")
 function Distributions.rand(
     rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
 ) where {L, D}
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     scale_diag = diag(scale)
     scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
diff --git a/docs/src/examples.md b/docs/src/examples.md
index 078c004ad..9ecd3d26d 100644
--- a/docs/src/examples.md
+++ b/docs/src/examples.md
@@ -15,7 +15,6 @@ Using the `LogDensityProblems` interface, we the model can be defined as follows

 ```@example elboexample
 using LogDensityProblems
-using SimpleUnPack

 struct NormalLogNormal{MX,SX,MY,SY}
     μ_x::MX
@@ -25,7 +24,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end

 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
@@ -59,7 +58,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat
 using Bijectors

 function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     return Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:(1 + length(μ_y))],
diff --git a/docs/src/optimization.md b/docs/src/optimization.md
index 315f896eb..05fe035d4 100644
--- a/docs/src/optimization.md
+++ b/docs/src/optimization.md
@@ -24,3 +24,5 @@ PolynomialAveraging
 ```

 [^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
+[^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
+[^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.
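The `@unpack` rewrites above, and all of those that follow, rely on Julia's built-in property destructuring, available since 1.7 and hence safe under the new `julia = "1.10"` lower bound. A self-contained sketch of the two forms, reusing the `NormalLogNormal` struct from the docs:

```julia
# Property destructuring (Base Julia ≥ 1.7) replaces SimpleUnPack.@unpack.
struct NormalLogNormal{MX,SX,MY,SY}
    μ_x::MX
    σ_x::SX
    μ_y::MY
    Σ_y::SY
end

model = NormalLogNormal(0.0, 1.0, zeros(2), [1.0 0.0; 0.0 1.0])

# Before (requires the SimpleUnPack dependency):
#   @unpack μ_x, σ_x, μ_y, Σ_y = model
# After (no dependency; binds each listed field to a local variable):
(; μ_x, σ_x, μ_y, Σ_y) = model
```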
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 4f4441a6e..d9f12b7e2 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,7 +1,6 @@
 module AdvancedVI

-using SimpleUnPack: @unpack, @pack!
 using Accessors

 using Random
diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl
index 4193c0873..b9bc99faa 100644
--- a/src/families/location_scale.jl
+++ b/src/families/location_scale.jl
@@ -37,8 +37,8 @@ function (re::RestructureMeanField)(flat::AbstractVector)
     return MvLocationScale(location, scale, re.model.dist)
 end

-function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
-    @unpack location, scale, dist = q
+function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
+    (; location, scale, dist) = q
     flat = vcat(location, diag(scale))
     return flat, RestructureMeanField(q)
 end
@@ -51,19 +51,19 @@ Base.size(q::MvLocationScale) = size(q.location)
 Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)

 function StatsBase.entropy(q::MvLocationScale)
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     # `convert` is necessary because `entropy` is not type stable upstream
     return n_dims * convert(eltype(location), entropy(dist)) + logdet(scale)
 end

 function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
 end

 function Distributions.rand(q::MvLocationScale)
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     return scale * rand(dist, n_dims) + location
 end
@@ -71,7 +71,7 @@ end
 function Distributions.rand(
     rng::AbstractRNG, q::MvLocationScale{S,D,L}, num_samples::Int
 ) where {S,D,L}
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     return scale * rand(rng, dist, n_dims, num_samples) .+ location
 end
@@ -80,7 +80,7 @@ end
 function Distributions.rand(
     rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
 ) where {L,D}
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     scale_diag = diag(scale)
     return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
@@ -89,14 +89,14 @@ end
 function Distributions._rand!(
     rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}
 )
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     rand!(rng, dist, x)
     x[:] = scale * x
     return x .+= location
 end

 function Distributions.mean(q::MvLocationScale)
-    @unpack location, scale = q
+    (; location, scale) = q
     return location + scale * Fill(mean(q.dist), length(location))
 end
diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl
index 0e3ed4c6f..ba81862ba 100644
--- a/src/families/location_scale_low_rank.jl
+++ b/src/families/location_scale_low_rank.jl
@@ -33,7 +33,7 @@ Base.size(q::MvLocationScaleLowRank) = size(q.location)
 Base.eltype(::Type{<:MvLocationScaleLowRank{D,L,SD,SF}}) where {D,L,SD,SF} = eltype(L)

 function StatsBase.entropy(q::MvLocationScaleLowRank)
-    @unpack location, scale_diag, scale_factors, dist = q
+    (; location, scale_diag, scale_factors, dist) = q
     n_dims = length(location)
     scale_diag2 = scale_diag .* scale_diag
     UtDinvU = Hermitian(scale_factors' * (scale_factors ./ scale_diag2))
@@ -44,7 +44,7 @@ end
 function Distributions.logpdf(
     q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differntiable::Bool=false
 )
-    @unpack location, scale_diag, scale_factors, dist = q
+    (; location, scale_diag, scale_factors, dist) = q
     μ_base = mean(dist)
     n_dims = length(location)

@@ -67,7 +67,7 @@ function Distributions.logpdf(
 end

 function Distributions.rand(q::MvLocationScaleLowRank)
-    @unpack location, scale_diag, scale_factors, dist = q
+    (; location, scale_diag, scale_factors, dist) = q
     n_dims = length(location)
     n_factors = size(scale_factors, 2)
     u_diag = rand(dist, n_dims)
@@ -78,7 +78,7 @@ end
 function Distributions.rand(
     rng::AbstractRNG, q::MvLocationScaleLowRank, num_samples::Int
 )
-    @unpack location, scale_diag, scale_factors, dist = q
+    (; location, scale_diag, scale_factors, dist) = q
     n_dims = length(location)
     n_factors = size(scale_factors, 2)
     u_diag = rand(rng, dist, n_dims, num_samples)
@@ -89,7 +89,7 @@ end
 function Distributions._rand!(
     rng::AbstractRNG, q::MvLocationScaleLowRank, x::AbstractVecOrMat{<:Real}
 )
-    @unpack location, scale_diag, scale_factors, dist = q
+    (; location, scale_diag, scale_factors, dist) = q
     rand!(rng, dist, x)

     x[:] = scale_diag .* x
@@ -101,7 +101,7 @@ function Distributions._rand!(
 end

 function Distributions.mean(q::MvLocationScaleLowRank)
-    @unpack location, scale_diag, scale_factors = q
+    (; location, scale_diag, scale_factors) = q
     μ = mean(q.dist)
     return location +
            scale_diag .* Fill(μ, length(scale_diag)) +
@@ -109,14 +109,14 @@ function Distributions.mean(q::MvLocationScaleLowRank)
 end

 function Distributions.var(q::MvLocationScaleLowRank)
-    @unpack scale_diag, scale_factors = q
+    (; scale_diag, scale_factors) = q
     σ2 = var(q.dist)
     return σ2 *
            (scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1])
 end

 function Distributions.cov(q::MvLocationScaleLowRank)
-    @unpack scale_diag, scale_factors = q
+    (; scale_diag, scale_factors) = q
     σ2 = var(q.dist)
     return σ2 * (Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors')
 end
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index fa34022af..210b49ca9 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -37,10 +37,3 @@ function estimate_entropy(
         -logpdf(q, mc_sample)
     end
 end
-
-function estimate_entropy_maybe_stl(
-    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
-)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
-end
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index b8bf63fa8..d8079c2b5 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -45,6 +45,13 @@ function Base.show(io::IO, obj::RepGradELBO)
     return print(io, ")")
 end

+function estimate_entropy_maybe_stl(
+    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
+)
+    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
+    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
+end
+
 function estimate_energy_with_samples(prob, samples)
     return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
 end
@@ -85,9 +92,27 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samp
     return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
 end

-function estimate_repgradelbo_ad_forward(params′, aux)
-    @unpack rng, obj, problem, adtype, restructure, q_stop = aux
-    q = restructure_ad_forward(adtype, restructure, params′)
+"""
+    estimate_repgradelbo_ad_forward(params, aux)
+
+AD-guaranteed forward path of the reparameterization gradient objective.
+
+# Arguments
+- `params`: Variational parameters.
+- `aux`: Auxiliary information excluded from the AD path.
+
+# Auxiliary Information
+`aux` should contain the following entries:
+- `rng`: Random number generator.
+- `obj`: The `RepGradELBO` objective.
+- `problem`: The target `LogDensityProblem`.
+- `adtype`: The `ADType` used for differentiating the forward path.
+- `restructure`: Callable for restructuring the variational distribution from `params`.
+- `q_stop`: A copy of `restructure(params)` with its gradient "stopped" (excluded from the AD path).
+"""
+function estimate_repgradelbo_ad_forward(params, aux)
+    (; rng, obj, problem, adtype, restructure, q_stop) = aux
+    q = restructure_ad_forward(adtype, restructure, params)
     samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
     energy = estimate_energy_with_samples(problem, samples)
     elbo = energy + entropy
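The docstring above documents the `(params, aux)` convention shared by both forward paths: everything that must not be differentiated travels in `aux`, and only `params` enters the AD call. A minimal sketch of the pattern with ForwardDiff; `loss` and `target` are illustrative names, not AdvancedVI API:

```julia
# Only `params` is differentiated; `aux` is closed over and left out of AD.
using ForwardDiff

loss(params, aux) = sum(abs2, params .- aux.target) / (2 * aux.n)

aux = (target=[1.0, -1.0], n=2)   # auxiliary info, excluded from the AD path
params = zeros(2)

# Gradient w.r.t. params only: (params .- aux.target) ./ aux.n == [-0.5, 0.5]
grad = ForwardDiff.gradient(p -> loss(p, aux), params)
```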
""" -struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <: - AdvancedVI.AbstractVariationalObjective - entropy::EntropyEst +struct ScoreGradELBO <: AbstractVariationalObjective n_samples::Int - baseline_window_size::Int - baseline_history::Vector{Float64} -end - -function ScoreGradELBO( - n_samples::Int; - entropy::AbstractEntropyEstimator=ClosedFormEntropy(), - baseline_window_size::Int=10, - baseline_history::Vector{Float64}=Float64[], -) - return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history) end function Base.show(io::IO, obj::ScoreGradELBO) - print(io, "ScoreGradELBO(entropy=") - print(io, obj.entropy) - print(io, ", n_samples=") + print(io, "ScoreGradELBO(n_samples=") print(io, obj.n_samples) - print(io, ", baseline_window_size=") - print(io, obj.baseline_window_size) return print(io, ")") end -function compute_control_variate_baseline(history, window_size) - if length(history) == 0 - return 1.0 - end - min_index = max(1, length(history) - window_size) - return mean(history[min_index:end]) -end - -function estimate_energy_with_samples( - prob, samples_stop, samples_logprob, samples_logprob_stop, baseline -) - fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop)) - fv_mean = mean(fv) - score_grad = mean(@. samples_logprob * (fv - baseline)) - score_grad_stop = mean(@. samples_logprob_stop * (fv - baseline)) - return fv_mean + (score_grad - score_grad_stop) -end - function estimate_objective( rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples ) - samples, entropy = reparam_with_entropy(rng, q, q, obj.n_samples, obj.entropy) - energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) - return mean(energy) + entropy + samples = rand(rng, q, n_samples) + ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) + ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples)) + return mean(ℓπ - ℓq) end function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples) return estimate_objective(Random.default_rng(), obj, q, prob; n_samples) end -function estimate_scoregradelbo_ad_forward(params′, aux) - @unpack rng, obj, problem, adtype, restructure, q_stop = aux - baseline = compute_control_variate_baseline( - obj.baseline_history, obj.baseline_window_size - ) - q = restructure_ad_forward(adtype, restructure, params′) - samples_stop = rand(rng, q_stop, obj.n_samples) - entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop) - samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop)) - samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop)) - energy = estimate_energy_with_samples( - problem, samples_stop, samples_logprob, samples_logprob_stop, baseline - ) - elbo = energy + entropy - return -elbo +""" + estimate_scoregradelbo_ad_forward(params, aux) + +AD-guaranteed forward path of the score gradient objective. + +# Arguments +- `params`: Variational parameters. +- `aux`: Auxiliary information excluded from the AD path. + +# Auxiliary Information +`aux` should containt the following entries: +- `samples_stop`: Samples drawn from `q = restructure(params)` but with their gradients stopped (excluded from the AD path). +- `logprob_stop`: Log-densities of the target `LogDensityProblem` evaluated over `samples_stop`. +- `adtype`: The `ADType` used for differentiating the forward path. +- `restructure`: Callable for restructuring the varitional distribution from `params`. 
+""" +function estimate_scoregradelbo_ad_forward(params, aux) + (; samples_stop, logprob_stop, adtype, restructure) = aux + q = restructure_ad_forward(adtype, restructure, params) + ℓπ = logprob_stop + ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop)) + f = ℓq - ℓπ + return (mean(abs2, f) - mean(f)^2) / 2 end function AdvancedVI.estimate_gradient!( @@ -120,20 +70,15 @@ function AdvancedVI.estimate_gradient!( restructure, state, ) - q_stop = restructure(params) - aux = ( - rng=rng, - adtype=adtype, - obj=obj, - problem=prob, - restructure=restructure, - q_stop=q_stop, - ) + q = restructure(params) + samples = rand(rng, q, obj.n_samples) + ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) + aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) AdvancedVI.value_and_gradient!( adtype, estimate_scoregradelbo_ad_forward, params, aux, out ) - nelbo = DiffResults.value(out) - stat = (elbo=-nelbo,) - push!(obj.baseline_history, -nelbo) + ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples)) + elbo = mean(ℓπ - ℓq) + stat = (elbo=elbo,) return out, nothing, stat end diff --git a/src/optimization/rules.jl b/src/optimization/rules.jl index 7bb65e86d..6bdd82777 100644 --- a/src/optimization/rules.jl +++ b/src/optimization/rules.jl @@ -8,8 +8,6 @@ It's only parameter is the initial guess of the Euclidean distance to the optimu # Parameters - `repsilon`: Initial guess of the Euclidean distance between the initial point and the optimum. (default value: `1e-6`) - -[^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). Dowg unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769. """ Optimisers.@def struct DoWG <: Optimisers.AbstractRule repsilon = 1e-6 @@ -37,8 +35,6 @@ The original paper recommends \$ 10^{-4} ( 1 + \\lVert \\lambda_0 \\rVert ) \$, # Parameters - `repsilon`: Initial guess of the Euclidean distance between the initial point and the optimum. (default value: `1e-6`) - -[^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). Dog is sgd's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR. 
""" Optimisers.@def struct DoG <: Optimisers.AbstractRule repsilon = 1e-6 diff --git a/test/Project.toml b/test/Project.toml index bbf9c4c61..3a71865b4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,17 +5,18 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -25,25 +26,27 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1, 1" -Bijectors = "0.13" +Bijectors = "0.13, 0.14, 0.15" DiffResults = "1" DifferentiationInterface = "0.6" Distributions = "0.25.111" DistributionsAD = "0.6.45" +Enzyme = "0.13, 0.14, 0.15" FillArrays = "1.6.1" ForwardDiff = "0.10.36" -Functors = "0.4.5" +Functors = "0.4.5, 0.5" LinearAlgebra = "1" LogDensityProblems = "2.1.1" -Optimisers = "0.2.16, 0.3" +Mooncake = "0.4" +Optimisers = "0.2.16, 0.3, 0.4" PDMats = "0.11.7" +Pkg = "1" Random = "1" ReverseDiff = "1.15.1" -SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" Statistics = "1" StatsBase = "0.34" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" -julia = "1.6" +julia = "1.10, 1.11.2" diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index 92d15f6ab..e9c0bda95 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -78,7 +78,7 @@ @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) - @test z_sample_ref == rand(StableRNG(1), q) + @test z_sample_ref ≈ rand(StableRNG(1), q) end @testset "rand batch" begin @@ -104,7 +104,7 @@ end z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) - @test z_samples == z_samples_ret + @test z_samples ≈ z_samples_ret @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) @@ -124,7 +124,7 @@ @testset "rand! 
AbstractMatrix" begin z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) - @test z_samples == z_samples_ret + @test z_samples ≈ z_samples_ret @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) @@ -138,7 +138,7 @@ z_samples = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples) - @test z_samples_ref == z_samples + @test z_samples_ref ≈ z_samples end end end diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index fbe70ae9d..286011ad6 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,17 +1,12 @@ -AD_distributionsad = Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment - :Zygote => AutoZygote(), -) - -if @isdefined(Mooncake) - AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) -end - -if @isdefined(Enzyme) - AD_distributionsad[:Enzyme] = AutoEnzyme(; - mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const +AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end @@ -19,19 +14,18 @@ end @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield), - n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOClosedFormEntropy => RepGradELBO(10), :RepGradELBOStickingTheLanding => - RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), + RepGradELBO(10; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_distributionsad + (adbackname, adtype) in AD_repgradelbo_distributionsad seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 η = 1e-3 diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index db89f3c42..a3aee8d03 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -1,38 +1,32 @@ -AD_locationscale = Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), -) - -if @isdefined(Mooncake) - AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) -end - -if @isdefined(Enzyme) - AD_locationscale[:Enzyme] = AutoEnzyme(; - mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const +AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end -@testset "inference ScoreGradELBO VILocationScale" begin +@testset "inference RepGradELBO VILocationScale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => 
normal_meanfield, :Normal => normal_fullrank), - n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOClosedFormEntropy => RepGradELBO(10), :RepGradELBOStickingTheLanding => - RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), + RepGradELBO(10; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_locationscale + (adbackname, adtype) in AD_repgradelbo_locationscale seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 η = 1e-3 diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 338f56151..9594016ce 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -1,17 +1,12 @@ -AD_locationscale_bijectors = Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), -) - -if @isdefined(Mooncake) - AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) -end - -if @isdefined(Enzyme) - AD_locationscale_bijectors[:Enzyme] = AutoEnzyme(; - mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const +AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end @@ -20,19 +15,18 @@ end [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), - n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOClosedFormEntropy => RepGradELBO(10), :RepGradELBOStickingTheLanding => - RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), + RepGradELBO(10; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_locationscale_bijectors + (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 η = 1e-3 diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index 9a621b402..c7aa9a44c 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -1,38 +1,30 @@ -AD_scoregradelbo_distributionsad = Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment - :Zygote => AutoZygote(), -) - -if @isdefined(Mooncake) - AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) +AD_scoregradelbo_distributionsad = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + #:Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) end -#if @isdefined(Enzyme) -# AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme() -#end - @testset 
"inference ScoreGradELBO DistributionsAD" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield), - n_montecarlo in [1, 10], - (objname, objective) in Dict( - :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), - :ScoreGradELBOStickingTheLanding => - ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), - ), + (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)), (adbackname, adtype) in AD_scoregradelbo_distributionsad seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 - η = 1e-5 + η = 1e-4 opt = Optimisers.Descent(realtype(η)) # For small enough η, the error of SGD, Δλ, is bounded as diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index 60623d6fc..4b822e8cd 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -1,17 +1,12 @@ -AD_scoregradelbo_locationscale = Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), -) - -if @isdefined(Mooncake) - AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) -end - -if @isdefined(Enzyme) - AD_scoregradelbo_locationscale[:Enzyme] = AutoEnzyme(; - mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const +AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end @@ -20,22 +15,17 @@ end [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), - n_montecarlo in [1, 10], - (objname, objective) in Dict( - :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), - :ScoreGradELBOStickingTheLanding => - ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), - ), - (adbackname, adtype) in AD_locationscale + (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)), + (adbackname, adtype) in AD_scoregradelbo_locationscale seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 - η = 1e-5 + η = 1e-4 opt = ProjectScale(Optimisers.Descent(realtype(η))) # For small enough η, the error of SGD, Δλ, is bounded as diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index 7d638ff3c..8fa5cac0b 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -1,16 +1,13 @@ -AD_scoregradelbo_locationscale_bijectors = Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), -) - -#if @isdefined(Tapir) -# AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) -#end - -if @isdefined(Enzyme) - AD_scoregradelbo_locationscale_bijectors[:Enzyme] = AutoEnzyme() +AD_scoregradelbo_locationscale_bijectors = if 
TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) end @testset "inference ScoreGradELBO VILocationScale Bijectors" begin @@ -18,22 +15,17 @@ end [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), - n_montecarlo in [1, 10], - (objname, objective) in Dict( - #:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet. - :ScoreGradELBOStickingTheLanding => - ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), - ), + (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)), (adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 - η = 1e-5 + η = 1e-4 opt = ProjectScale(Optimisers.Descent(realtype(η))) b = Bijectors.bijector(model) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index e23aec580..5f8e8f0fb 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -1,22 +1,19 @@ using Test -const interface_ad_backends = Dict( - :ForwardDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), -) - -if @isdefined(Mooncake) - interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) -end - -if @isdefined(Enzyme) - interface_ad_backends[:Enzyme] = AutoEnzyme() +AD_interface = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) end @testset "ad" begin - @testset "$(adname)" for (adname, adtype) in interface_ad_backends + @testset "$(adname)" for (adname, adtype) in AD_interface D = 10 A = randn(D, D) λ = randn(D) diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index d7294ccaf..c51e39cdc 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -8,7 +8,7 @@ using Test T = 1000 modelstats = normal_meanfield(rng, Float64) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats # Global Test Configurations q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index be835e203..da3a59acc 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -1,5 +1,14 @@ -using Test +AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" + [AutoEnzyme()] +else + [ + AutoForwardDiff(), + AutoReverseDiff(), + AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), + ] +end @testset "interface RepGradELBO" begin seed = (0x38bef07cf9cc549d) @@ -7,9 +16,26 @@ using Test modelstats = normal_meanfield(rng, Float64) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) + + @testset "basic" begin + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] + obj = 
RepGradELBO(n_montecarlo) + _, _, stats, _ = optimize( + rng, + model, + obj, + q0, + 10; + optimizer=Descent(1e-5), + show_progress=false, + adtype=adtype, + ) + @assert isfinite(last(stats).elbo) + end + end obj = RepGradELBO(10) rng = StableRNG(seed) @@ -32,29 +58,14 @@ end rng = StableRNG(seed) modelstats = normal_meanfield(rng, Float64) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - ad_backends = [ - ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() - ] - if @isdefined(Mooncake) - push!(ad_backends, AutoMooncake(; config=Mooncake.Config())) - end - if @isdefined(Enzyme) - push!( - ad_backends, - AutoEnzyme(; - mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const - ), - ) - end + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - @testset for adtype in ad_backends + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) params, re = Optimisers.destructure(q_true) - obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()) out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) aux = ( diff --git a/test/interface/scoregradelbo.jl b/test/interface/scoregradelbo.jl index 8a6ebb14f..f368626e7 100644 --- a/test/interface/scoregradelbo.jl +++ b/test/interface/scoregradelbo.jl @@ -1,5 +1,14 @@ -using Test +AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme" + [AutoEnzyme()] +else + [ + AutoForwardDiff(), + AutoReverseDiff(), + AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), + ] +end @testset "interface ScoreGradELBO" begin seed = (0x38bef07cf9cc549d) @@ -7,9 +16,26 @@ using Test modelstats = normal_meanfield(rng, Float64) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats + + q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) + + @testset "basic" begin + @testset for adtype in AD_scoregradelbo_interface, n_montecarlo in [1, 10] + obj = ScoreGradELBO(n_montecarlo) + _, _, stats, _ = optimize( + rng, + model, + obj, + q0, + 10; + optimizer=Descent(1e-5), + show_progress=false, + adtype=adtype, + ) + @assert isfinite(last(stats).elbo) + end + end obj = ScoreGradELBO(10) rng = StableRNG(seed) diff --git a/test/models/normal.jl b/test/models/normal.jl index 9fc6ae38a..5826547df 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -5,7 +5,7 @@ struct TestNormal{M,S} end function LogDensityProblems.logdensity(model::TestNormal, θ) - @unpack μ, Σ = model + (; μ, Σ) = model return logpdf(MvNormal(μ, Σ), θ) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 176aab2f5..00949bc1b 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) end function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model return Bijectors.Stacked( 
Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:(1 + length(μ_y))], diff --git a/test/runtests.jl b/test/runtests.jl index 7c0e3129e..6a0bf7af2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,6 @@ using Optimisers using PDMats using Pkg using Random, StableRNGs -using SimpleUnPack: @unpack using Statistics using StatsBase @@ -22,18 +21,15 @@ using DistributionsAD @functor TuringDiagMvNormal using ADTypes -using ForwardDiff, ReverseDiff, Zygote - -if VERSION >= v"1.10" - Pkg.add("Mooncake") - Pkg.add("Enzyme") - using Mooncake - using Enzyme -end +using ForwardDiff, ReverseDiff, Zygote, Mooncake using AdvancedVI -const GROUP = get(ENV, "GROUP", "All") +const TEST_GROUP = get(ENV, "TEST_GROUP", "All") + +if TEST_GROUP == "Enzyme" + using Enzyme +end # Models for Inference Tests struct TestModel{M,L,S,SC} @@ -47,24 +43,28 @@ end include("models/normal.jl") include("models/normallognormal.jl") -# Tests -if GROUP == "All" || GROUP == "Interface" - include("interface/ad.jl") +if TEST_GROUP == "All" || TEST_GROUP == "Interface" + # Interface tests that do not involve testing on Enzyme include("interface/optimize.jl") - include("interface/repgradelbo.jl") - include("interface/scoregradelbo.jl") include("interface/rules.jl") include("interface/averaging.jl") + include("interface/scoregradelbo.jl") +end + +if TEST_GROUP == "All" || TEST_GROUP == "Interface" || TEST_GROUP == "Enzyme" + # Interface tests that involve testing on Enzyme + include("interface/ad.jl") + include("interface/repgradelbo.jl") end -if GROUP == "All" || GROUP == "Families" +if TEST_GROUP == "All" || TEST_GROUP == "Families" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") end const PROGRESS = haskey(ENV, "PROGRESS") -if GROUP == "All" || GROUP == "Inference" +if TEST_GROUP == "All" || TEST_GROUP == "Inference" || TEST_GROUP == "Enzyme" include("inference/repgradelbo_distributionsad.jl") include("inference/repgradelbo_locationscale.jl") include("inference/repgradelbo_locationscale_bijectors.jl")