From b49cf3e8cf2162706824735f0662559d6f838d55 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 14 Mar 2023 19:13:41 +0000 Subject: [PATCH 001/206] refactor ADVI, change gradient operation interface --- Project.toml | 1 + src/AdvancedVI.jl | 181 ++++++++++++++--------------------------- src/advi.jl | 47 ----------- src/estimators/advi.jl | 29 +++++++ src/utils.jl | 15 ++++ 5 files changed, 107 insertions(+), 166 deletions(-) create mode 100644 src/estimators/advi.jl diff --git a/Project.toml b/Project.toml index 28adc66a..71a2cbdc 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e203a13c..d42683d0 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -33,20 +33,12 @@ function __init__() export ZygoteAD function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ZygoteAD}, - q, - model, - θ::AbstractVector{<:Real}, + f::Function, + ::Type{<:ZygoteAD}, + λ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, - args... ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - y, back = Zygote.pullback(f, θ) + y, back = Zygote.pullback(f, λ) dy = first(back(1.0)) DiffResults.value!(out, y) DiffResults.gradient!(out, dy) @@ -58,21 +50,13 @@ function __init__() export ReverseDiffAD function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}}, - q, - model, - θ::AbstractVector{<:Real}, + f::Function, + ::Type{<:ReverseDiffAD}, + λ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, - args... ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - tp = AdvancedVI.tape(f, θ) - ReverseDiff.gradient!(out, tp, θ) + tp = AdvancedVI.tape(f, λ) + ReverseDiff.gradient!(out, tp, λ) return out end end @@ -81,26 +65,18 @@ function __init__() export EnzymeAD function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.EnzymeAD}, - q, - model, - θ::AbstractVector{<:Real}, + f::Function, + ::Type{<:EnzymeAD}, + λ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, - args... ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end # Use `Enzyme.ReverseWithPrimal` once it is released: # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(θ) + y = f(λ) DiffResults.value!(out, y) dy = DiffResults.gradient(out) fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy)) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) return out end end @@ -109,16 +85,8 @@ end export vi, ADVI, - ELBO, - elbo, TruncatedADAGrad, - DecayedADAGrad, - VariationalInference - -abstract type VariationalInference{AD} end - -getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD) -getADtype(::VariationalInference{AD}) where AD = AD + DecayedADAGrad abstract type VariationalObjective end @@ -126,13 +94,11 @@ const VariationalPosterior = Distribution{Multivariate, Continuous} """ - grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...) + grad!(f, λ, out) -Computes the gradients used in `optimize!`. Default implementation is provided for +Computes the gradients of the objective f. Default implementation is provided for `VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. This implicitly also gives a default implementation of `optimize!`. - -Variance reduction techniques, e.g. control variates, should be implemented in this function. """ function grad! end @@ -157,51 +123,36 @@ function update end # default implementations function grad!( - vo, - alg::VariationalInference{<:ForwardDiffAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... + f::Function, + adtype::Type{<:ForwardDiffAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - f(θ_) = if (q isa Distribution) - - vo(alg, update(q, θ_), model, args...) - else - - vo(alg, q(θ_), model, args...) - end - # Set chunk size and do ForwardMode. - chunk_size = getchunksize(typeof(alg)) + chunk_size = getchunksize(adtype) config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, θ) + ForwardDiff.GradientConfig(f, λ) else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) end - ForwardDiff.gradient!(out, f, θ, config) + ForwardDiff.gradient!(out, f, λ, config) end function grad!( - vo, - alg::VariationalInference{<:TrackerAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... + f::Function, + ::Type{<:TrackerAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - θ_tracked = Tracker.param(θ) - y = if (q isa Distribution) - - vo(alg, update(q, θ_tracked), model, args...) - else - - vo(alg, q(θ_tracked), model, args...) - end + λ_tracked = Tracker.param(λ) + y = f(λ_tracked) Tracker.back!(y, 1.0) DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(θ_tracked)) + DiffResults.gradient!(out, Tracker.grad(λ_tracked)) end +abstract type AbstractGradientEstimator end """ optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) @@ -210,61 +161,53 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer the steps. """ function optimize!( - vo, - alg::VariationalInference, - q, - model, - θ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad() + grad_estimator::AbstractGradientEstimator, + rebuild::Function, + ℓπ::Function, + n_max_iter::Int, + λ::AbstractVector{<:Real}; + optimizer = TruncatedADAGrad(), + rng = Random.GLOBAL_RNG ) - # TODO: should we always assume `samples_per_step` and `max_iters` for all algos? - alg_name = alg_str(alg) - samples_per_step = alg.samples_per_step - max_iters = alg.max_iters - - num_params = length(θ) + obj_name = objective(grad_estimator) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc)) + if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) # this message should only occurr once in the optimization process - @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ) + @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) end - diff_result = DiffResults.GradientResult(θ) + grad_buf = DiffResults.GradientResult(λ) i = 0 - prog = if PROGRESS[] - ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0) - else - 0 - end + prog = ProgressMeter.Progress( + n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) # add criterion? A running mean maybe? - time_elapsed = @elapsed while (i < max_iters) # & converged - grad!(vo, alg, q, model, θ, diff_result, samples_per_step) - - # apply update rule - Δ = DiffResults.gradient(diff_result) - Δ = apply!(optimizer, θ, Δ) - @. θ = θ - Δ + time_elapsed = @elapsed begin + for i = 1:n_max_iter + stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, ℓπ, grad_buf) + + # apply update rule + Δλ = DiffResults.gradient(grad_buf) + Δλ = apply!(optimizer, λ, Δλ) + @. λ = λ - Δλ + + stat′ = (Δλ=norm(Δλ),) + stats = merge(stats, stat′) - AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) - PROGRESS[] && (ProgressMeter.next!(prog)) - - i += 1 + AdvancedVI.DEBUG && @debug "Step $i" stats... + pm_next!(prog, stats) + end end - - return θ + return λ end # objectives -include("objectives.jl") +include("estimators/advi.jl") # optimisers include("optimisers.jl") -# VI algorithms -include("advi.jl") - end # module diff --git a/src/advi.jl b/src/advi.jl index 7f9e7346..be9823db 100644 --- a/src/advi.jl +++ b/src/advi.jl @@ -50,50 +50,3 @@ function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = Truncate return θ end -# WITHOUT updating parameters inside ELBO -function (elbo::ELBO)( - rng::Random.AbstractRNG, - alg::ADVI, - q::VariationalPosterior, - logπ::Function, - num_samples -) - # 𝔼_q(z)[log p(xᵢ, z)] - # = ∫ log p(xᵢ, z) q(z) dz - # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables) - # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] - - # 𝔼_q(z)[log q(z)] - # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ) [log q(f(ϕ))] - # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|] - # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - - # Finally, the ELBO is given by - # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)] - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ)) - - # If f: supp(p(z | x)) → ℝ then - # ELBO = 𝔼[log p(x, z) - log q(z)] - # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) - # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) - - # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` - z, logjac = rand_and_logjac(rng, q) - res = (logπ(z) + logjac) / num_samples - - if q isa TransformedDistribution - res += entropy(q.dist) - else - res += entropy(q) - end - - for i = 2:num_samples - z, logjac = rand_and_logjac(rng, q) - res += (logπ(z) + logjac) / num_samples - end - - return res -end diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl new file mode 100644 index 00000000..c5a83957 --- /dev/null +++ b/src/estimators/advi.jl @@ -0,0 +1,29 @@ + +struct ADVI <: AbstractGradientEstimator + n_samples::Int +end + +objective(::ADVI) = "ELBO" + +function estimate_gradient!( + rng::Random.AbstractRNG, + estimator::ADVI, + λ::Vector{<:Real}, + rebuild::Function, + logπ::Function, + out::DiffResults.MutableDiffResult) + + n_samples = estimator.n_samples + + grad!(ADBackend(), λ, out) do λ′ + q = rebuild(λ′) + zs, ∑logjac = rand_and_logjac(rng, q, estimator.n_samples) + + elbo = mapreduce(+, eachcol(zs)) do zᵢ + (logπ(zᵢ) + ∑logjac) + end / n_samples + -elbo + end + nelbo = DiffResults.value(out) + (elbo=-nelbo,) +end diff --git a/src/utils.jl b/src/utils.jl index bb4c1f18..87cc0856 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,3 +13,18 @@ function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDis y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) return y, logjac end + +function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution, n_samples::Int) + x = rand(rng, dist, n_samples) + return x, zero(eltype(x)) +end + +function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution, n_samples::Int) + x = rand(rng, dist.dist, n_samples) + y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) + return y, logjac +end + +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end From 88e0b79758c2f207b9d3c7120b469af837049fec Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 14 Mar 2023 19:56:47 +0000 Subject: [PATCH 002/206] remove unused file, remove unused dependency --- Project.toml | 1 - src/objectives.jl | 7 ------- 2 files changed, 8 deletions(-) delete mode 100644 src/objectives.jl diff --git a/Project.toml b/Project.toml index 71a2cbdc..28adc66a 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/objectives.jl b/src/objectives.jl deleted file mode 100644 index 5a6b61b0..00000000 --- a/src/objectives.jl +++ /dev/null @@ -1,7 +0,0 @@ -struct ELBO <: VariationalObjective end - -function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...) - return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...) -end - -const elbo = ELBO() From c2fb3f8d08c15b16fa2e84a359b0d9bda3bf45b2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 15 Mar 2023 18:53:50 +0000 Subject: [PATCH 003/206] fix ADVI elbo computation more efficiently --- src/estimators/advi.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index c5a83957..44f65909 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -17,13 +17,17 @@ function estimate_gradient!( grad!(ADBackend(), λ, out) do λ′ q = rebuild(λ′) - zs, ∑logjac = rand_and_logjac(rng, q, estimator.n_samples) - - elbo = mapreduce(+, eachcol(zs)) do zᵢ - (logπ(zᵢ) + ∑logjac) - end / n_samples + zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples) + + 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ + logπ(zᵢ) / n_samples + end + 𝔼logdetjac = ∑logdetjac/n_samples + + elbo = 𝔼logπ + 𝔼logdetjac -elbo end nelbo = DiffResults.value(out) (elbo=-nelbo,) end + From 83161fdf7fd18d9f686483da38174148ad305c9f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 15 Mar 2023 19:20:51 +0000 Subject: [PATCH 004/206] fix missing entropy regularization term --- src/estimators/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index 44f65909..ad45efbb 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -24,7 +24,7 @@ function estimate_gradient!( end 𝔼logdetjac = ∑logdetjac/n_samples - elbo = 𝔼logπ + 𝔼logdetjac + elbo = 𝔼logπ + 𝔼logdetjac + entropy(q) -elbo end nelbo = DiffResults.value(out) From efa810687738f4d297ff8b25aaadf28e37ba2080 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 18 Mar 2023 01:04:02 +0000 Subject: [PATCH 005/206] add LogDensityProblem interface --- Project.toml | 1 + src/AdvancedVI.jl | 5 +++-- src/estimators/advi.jl | 19 ++++++++++++++++--- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 28adc66a..6ad4b689 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d42683d0..e1ac752f 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -7,6 +7,8 @@ using DocStringExtensions using ProgressMeter, LinearAlgebra +using LogDensityProblems + using ForwardDiff using Tracker @@ -163,7 +165,6 @@ the steps. function optimize!( grad_estimator::AbstractGradientEstimator, rebuild::Function, - ℓπ::Function, n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), @@ -187,7 +188,7 @@ function optimize!( # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, ℓπ, grad_buf) + stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) # apply update rule Δλ = DiffResults.gradient(grad_buf) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index ad45efbb..5a8652b6 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -1,8 +1,22 @@ -struct ADVI <: AbstractGradientEstimator +struct ADVI{Tlogπ} <: AbstractGradientEstimator + ℓπ::Tlogπ n_samples::Int end +function ADVI(ℓπ, n_samples; kwargs...) + # ADVI requires gradients of log-likelihood + cap = LogDensityProblems.capabilities(ℓπ) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ADVI(Base.Fix1(LogDensityProblems.logdensity, ℓπ), n_samples) +end + objective(::ADVI) = "ELBO" function estimate_gradient!( @@ -10,7 +24,6 @@ function estimate_gradient!( estimator::ADVI, λ::Vector{<:Real}, rebuild::Function, - logπ::Function, out::DiffResults.MutableDiffResult) n_samples = estimator.n_samples @@ -20,7 +33,7 @@ function estimate_gradient!( zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples) 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ - logπ(zᵢ) / n_samples + estimator.ℓπ(zᵢ) / n_samples end 𝔼logdetjac = ∑logdetjac/n_samples From 4ae2fbfa832662b5adaa7e3d423cb312cb87b4c9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 18 Mar 2023 02:22:32 +0000 Subject: [PATCH 006/206] refactor use bijectors directly instead of transformed distributions This is to avoid having to reconstruct transformed distributions all the time. The direct use of bijectors also avoids going through lots of abstraction layers that could break. Instead, transformed distributions could be constructed only once when returing the VI result. --- src/estimators/advi.jl | 43 ++++++++++++++++++++++++++---------------- src/utils.jl | 30 ----------------------------- 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index 5a8652b6..9784e924 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -1,22 +1,32 @@ -struct ADVI{Tlogπ} <: AbstractGradientEstimator +struct ADVI{Tlogπ, B <: Union{Function, Bijectors.Inverse{<:Bijectors.Bijector}}} <: AbstractGradientEstimator + # Automatic differentiation variational inference + # + # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). + # Automatic differentiation variational inference. + # Journal of machine learning research. + ℓπ::Tlogπ + b⁻¹::B n_samples::Int -end -function ADVI(ℓπ, n_samples; kwargs...) - # ADVI requires gradients of log-likelihood - cap = LogDensityProblems.capabilities(ℓπ) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) + function ADVI(prob, b⁻¹::B, n_samples; kwargs...) where {B <: Bijectors.Inverse{<:Bijectors.Bijector}} + # Could check whether the support of b⁻¹ and ℓπ match + cap = LogDensityProblems.capabilities(prob) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹, n_samples) end - ADVI(Base.Fix1(LogDensityProblems.logdensity, ℓπ), n_samples) end +ADVI(prob, n_samples; kwargs...) = ADVI(prob, identity, n_samples; kwargs...) + objective(::ADVI) = "ELBO" function estimate_gradient!( @@ -29,18 +39,19 @@ function estimate_gradient!( n_samples = estimator.n_samples grad!(ADBackend(), λ, out) do λ′ - q = rebuild(λ′) - zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples) + q_η = rebuild(λ′) + ηs = rand(rng, q_η, estimator.n_samples) + + zs, ∑logdetjac = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηs) 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ estimator.ℓπ(zᵢ) / n_samples end 𝔼logdetjac = ∑logdetjac/n_samples - elbo = 𝔼logπ + 𝔼logdetjac + entropy(q) + elbo = 𝔼logπ + 𝔼logdetjac + entropy(q_η) -elbo end nelbo = DiffResults.value(out) (elbo=-nelbo,) end - diff --git a/src/utils.jl b/src/utils.jl index 87cc0856..e69de29b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,30 +0,0 @@ -using Distributions - -using Bijectors: Bijectors - - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution) - x = rand(rng, dist) - return x, zero(eltype(x)) -end - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution) - x = rand(rng, dist.dist) - y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) - return y, logjac -end - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution, n_samples::Int) - x = rand(rng, dist, n_samples) - return x, zero(eltype(x)) -end - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution, n_samples::Int) - x = rand(rng, dist.dist, n_samples) - y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) - return y, logjac -end - -function pm_next!(pm, stats::NamedTuple) - ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) -end From 1cadb51a011eeaf0b7d3e05aee7e45494bc2439a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 00:54:02 +0100 Subject: [PATCH 007/206] fix type restrictions --- src/estimators/advi.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index 9784e924..b4b3a9d0 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -1,5 +1,5 @@ -struct ADVI{Tlogπ, B <: Union{Function, Bijectors.Inverse{<:Bijectors.Bijector}}} <: AbstractGradientEstimator +struct ADVI{Tlogπ, B} <: AbstractGradientEstimator # Automatic differentiation variational inference # # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). @@ -33,7 +33,7 @@ function estimate_gradient!( rng::Random.AbstractRNG, estimator::ADVI, λ::Vector{<:Real}, - rebuild::Function, + rebuild, out::DiffResults.MutableDiffResult) n_samples = estimator.n_samples From 3474e8d2c97032f7a384d3b88cb7cc47bdae12f3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 00:54:23 +0100 Subject: [PATCH 008/206] remove unused file --- src/advi.jl | 52 ---------------------------------------------------- 1 file changed, 52 deletions(-) delete mode 100644 src/advi.jl diff --git a/src/advi.jl b/src/advi.jl deleted file mode 100644 index be9823db..00000000 --- a/src/advi.jl +++ /dev/null @@ -1,52 +0,0 @@ -using StatsFuns -using DistributionsAD -using Bijectors -using Bijectors: TransformedDistribution - - -""" -$(TYPEDEF) - -Automatic Differentiation Variational Inference (ADVI) with automatic differentiation -backend `AD`. - -# Fields - -$(TYPEDFIELDS) -""" -struct ADVI{AD} <: VariationalInference{AD} - "Number of samples used to estimate the ELBO in each optimization step." - samples_per_step::Int - "Maximum number of gradient steps." - max_iters::Int -end - -function ADVI(samples_per_step::Int=1, max_iters::Int=1000) - return ADVI{ADBackend()}(samples_per_step, max_iters) -end - -alg_str(::ADVI) = "ADVI" - -function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - # If `q` is a mean-field approx we use the specialized `update` function - if q isa Distribution - return update(q, θ) - else - # Otherwise we assume it's a mapping θ → q - return q(θ) - end -end - - -function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - - # `model` assumed to be callable z ↦ p(x, z) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - return θ -end - From 03a27679f98790f943b784d0f6282035ecdc8abe Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 03:19:03 +0100 Subject: [PATCH 009/206] fix use of with_logabsdet_jacobian --- src/estimators/advi.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index b4b3a9d0..701ec1ef 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -10,7 +10,7 @@ struct ADVI{Tlogπ, B} <: AbstractGradientEstimator b⁻¹::B n_samples::Int - function ADVI(prob, b⁻¹::B, n_samples; kwargs...) where {B <: Bijectors.Inverse{<:Bijectors.Bijector}} + function ADVI(prob, b⁻¹, n_samples; kwargs...) # Could check whether the support of b⁻¹ and ℓπ match cap = LogDensityProblems.capabilities(prob) if cap === nothing @@ -42,14 +42,12 @@ function estimate_gradient!( q_η = rebuild(λ′) ηs = rand(rng, q_η, estimator.n_samples) - zs, ∑logdetjac = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηs) - - 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ - estimator.ℓπ(zᵢ) / n_samples + 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηᵢ) + (estimator.ℓπ(zᵢ) + logdetjacᵢ) / n_samples end - 𝔼logdetjac = ∑logdetjac/n_samples - elbo = 𝔼logπ + 𝔼logdetjac + entropy(q_η) + elbo = 𝔼ℓ + entropy(q_η) -elbo end nelbo = DiffResults.value(out) From 09c44fb639864167e6548db89b7ad0196d04ddfc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 03:29:42 +0100 Subject: [PATCH 010/206] restructure project; move the main VI routine to its own file --- src/AdvancedVI.jl | 60 +++++++----------------------------------- src/vi.jl | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 50 deletions(-) create mode 100644 src/vi.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e1ac752f..d3612cb1 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -12,6 +12,13 @@ using LogDensityProblems using ForwardDiff using Tracker +using Bijectors: Bijectors + +using Distributions +using DistributionsAD + +using StatsFuns + const PROGRESS = Ref(true) function turnprogress(switch::Bool) @info("[AdvancedVI]: global PROGRESS is set as $switch") @@ -154,61 +161,14 @@ function grad!( DiffResults.gradient!(out, Tracker.grad(λ_tracked)) end +# estimators abstract type AbstractGradientEstimator end -""" - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) - -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. -""" -function optimize!( - grad_estimator::AbstractGradientEstimator, - rebuild::Function, - n_max_iter::Int, - λ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad(), - rng = Random.GLOBAL_RNG -) - obj_name = objective(grad_estimator) - - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) - end - - grad_buf = DiffResults.GradientResult(λ) - - i = 0 - prog = ProgressMeter.Progress( - n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) - - # add criterion? A running mean maybe? - time_elapsed = @elapsed begin - for i = 1:n_max_iter - stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) - - # apply update rule - Δλ = DiffResults.gradient(grad_buf) - Δλ = apply!(optimizer, λ, Δλ) - @. λ = λ - Δλ - - stat′ = (Δλ=norm(Δλ),) - stats = merge(stats, stat′) - - AdvancedVI.DEBUG && @debug "Step $i" stats... - pm_next!(prog, stats) - end - end - return λ -end - -# objectives include("estimators/advi.jl") # optimisers include("optimisers.jl") +include("vi.jl") + end # module diff --git a/src/vi.jl b/src/vi.jl new file mode 100644 index 00000000..aceb3f2d --- /dev/null +++ b/src/vi.jl @@ -0,0 +1,66 @@ + +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end + +""" + optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) + +Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute +the steps. +""" +function optimize( + grad_estimator::AbstractGradientEstimator, + rebuild::Function, + n_max_iter::Int, + λ::AbstractVector{<:Real}; + optimizer = TruncatedADAGrad(), + rng = Random.GLOBAL_RNG +) + obj_name = objective(grad_estimator) + + # TODO: really need a better way to warn the user about potentially + # not using the correct accumulator + if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) + # this message should only occurr once in the optimization process + @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) + end + + grad_buf = DiffResults.GradientResult(λ) + + i = 0 + prog = ProgressMeter.Progress( + n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) + + # add criterion? A running mean maybe? + time_elapsed = @elapsed begin + for i = 1:n_max_iter + stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) + + # apply update rule + Δλ = DiffResults.gradient(grad_buf) + Δλ = apply!(optimizer, λ, Δλ) + @. λ = λ - Δλ + + stat′ = (Δλ=norm(Δλ),) + stats = merge(stats, stat′) + + AdvancedVI.DEBUG && @debug "Step $i" stats... + pm_next!(prog, stats) + end + end + return λ +end + +# function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG) +# θ = copy(θ_init) +# optimize!(grad_estimator, rebuild, n_max_iter, λ, optimizer = optimizer, rng = rng) + +# # If `q` is a mean-field approx we use the specialized `update` function +# if q isa Distribution +# return update(q, θ) +# else +# # Otherwise we assume it's a mapping θ → q +# return q(θ) +# end +# end From b7407ceecd7f6c8e3fc7a4c443995347fd4659f5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 03:31:35 +0100 Subject: [PATCH 011/206] remove redundant import --- src/AdvancedVI.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d3612cb1..32b114ba 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -12,8 +12,6 @@ using LogDensityProblems using ForwardDiff using Tracker -using Bijectors: Bijectors - using Distributions using DistributionsAD From 40401494ef032b1c9623856ed668373b251aaccb Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 9 Jun 2023 00:51:56 +0100 Subject: [PATCH 012/206] restructure project into more modular objective estimators --- src/AdvancedVI.jl | 8 ++--- src/estimators/advi.jl | 55 ------------------------------ src/objectives/elbo/advi_energy.jl | 35 +++++++++++++++++++ src/objectives/elbo/elbo.jl | 44 ++++++++++++++++++++++++ src/objectives/elbo/entropy.jl | 18 ++++++++++ src/vi.jl | 10 +++--- 6 files changed, 105 insertions(+), 65 deletions(-) delete mode 100644 src/estimators/advi.jl create mode 100644 src/objectives/elbo/advi_energy.jl create mode 100644 src/objectives/elbo/elbo.jl create mode 100644 src/objectives/elbo/entropy.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 32b114ba..dfb22930 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -95,8 +95,6 @@ export TruncatedADAGrad, DecayedADAGrad -abstract type VariationalObjective end - const VariationalPosterior = Distribution{Multivariate, Continuous} @@ -160,9 +158,11 @@ function grad!( end # estimators -abstract type AbstractGradientEstimator end +abstract type AbstractVariationalObjective end -include("estimators/advi.jl") +include("objectives/elbo/elbo.jl") +include("objectives/elbo/advi_energy.jl") +include("objectives/elbo/entropy.jl") # optimisers include("optimisers.jl") diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl deleted file mode 100644 index 701ec1ef..00000000 --- a/src/estimators/advi.jl +++ /dev/null @@ -1,55 +0,0 @@ - -struct ADVI{Tlogπ, B} <: AbstractGradientEstimator - # Automatic differentiation variational inference - # - # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). - # Automatic differentiation variational inference. - # Journal of machine learning research. - - ℓπ::Tlogπ - b⁻¹::B - n_samples::Int - - function ADVI(prob, b⁻¹, n_samples; kwargs...) - # Could check whether the support of b⁻¹ and ℓπ match - cap = LogDensityProblems.capabilities(prob) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) - end - ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹, n_samples) - end -end - -ADVI(prob, n_samples; kwargs...) = ADVI(prob, identity, n_samples; kwargs...) - -objective(::ADVI) = "ELBO" - -function estimate_gradient!( - rng::Random.AbstractRNG, - estimator::ADVI, - λ::Vector{<:Real}, - rebuild, - out::DiffResults.MutableDiffResult) - - n_samples = estimator.n_samples - - grad!(ADBackend(), λ, out) do λ′ - q_η = rebuild(λ′) - ηs = rand(rng, q_η, estimator.n_samples) - - 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηᵢ) - (estimator.ℓπ(zᵢ) + logdetjacᵢ) / n_samples - end - - elbo = 𝔼ℓ + entropy(q_η) - -elbo - end - nelbo = DiffResults.value(out) - (elbo=-nelbo,) -end diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl new file mode 100644 index 00000000..b27b752e --- /dev/null +++ b/src/objectives/elbo/advi_energy.jl @@ -0,0 +1,35 @@ + +struct ADVIEnergy{Tlogπ, B} <: AbstractEnergyEstimator + # Automatic differentiation variational inference + # + # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). + # Automatic differentiation variational inference. + # Journal of machine learning research. + + ℓπ::Tlogπ + b⁻¹::B + + function ADVIEnergy(prob, b⁻¹) + # Could check whether the support of b⁻¹ and ℓπ match + cap = LogDensityProblems.capabilities(prob) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹) + end +end + +ADVIEnergy(prob) = ADVIEnergy(prob, identity) + +function (energy::ADVIEnergy)(q, ηs::AbstractMatrix) + n_samples = size(ηs, 2) + mapreduce(+, eachcol(ηs)) do ηᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(energy.b⁻¹, ηᵢ) + (energy.ℓπ(zᵢ) + logdetjacᵢ) / n_samples + end +end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl new file mode 100644 index 00000000..2954ae8e --- /dev/null +++ b/src/objectives/elbo/elbo.jl @@ -0,0 +1,44 @@ + +abstract type AbstractEnergyEstimator end +abstract type AbstractEntropyEstimator end + +struct ELBO{EnergyEst <: AbstractEnergyEstimator, + EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + # Evidence Lower Bound + # + # Jordan, Michael I., et al. + # "An introduction to variational methods for graphical models." + # Machine learning 37 (1999): 183-233. + + energy_estimator::EnergyEst + entropy_estimator::EntropyEst + n_samples::Int +end + +Base.string(::ELBO) = "ELBO" + +function ADVI(ℓπ, b⁻¹, n_samples::Int) + ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) +end + +function estimate_gradient!( + rng::Random.AbstractRNG, + objective::ELBO, + λ::Vector{<:Real}, + rebuild, + out::DiffResults.MutableDiffResult) + + n_samples = objective.n_samples + + grad!(ADBackend(), λ, out) do λ′ + q_η = rebuild(λ′) + ηs = rand(rng, q_η, n_samples) + + 𝔼ℓ = objective.energy_estimator(q_η, ηs) + ℍ = objective.entropy_estimator(q_η, ηs) + elbo = 𝔼ℓ + ℍ + -elbo + end + nelbo = DiffResults.value(out) + (elbo=-nelbo,) +end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl new file mode 100644 index 00000000..d7fb7054 --- /dev/null +++ b/src/objectives/elbo/entropy.jl @@ -0,0 +1,18 @@ + +struct ClosedFormEntropy <: AbstractEntropyEstimator +end + +function (::ClosedFormEntropy)(q, ηs::AbstractMatrix) + entropy(q) +end + +struct MonteCarloEntropy <: AbstractEntropyEstimator +end + +function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) + n_samples = size(ηs, 2) + mapreduce(+, eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) / n_samples + end +end + diff --git a/src/vi.jl b/src/vi.jl index aceb3f2d..4bf4595f 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -10,32 +10,30 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer the steps. """ function optimize( - grad_estimator::AbstractGradientEstimator, + objective::AbstractVariationalObjective, rebuild::Function, n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG ) - obj_name = objective(grad_estimator) - # TODO: really need a better way to warn the user about potentially # not using the correct accumulator if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) # this message should only occurr once in the optimization process - @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) + @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end grad_buf = DiffResults.GradientResult(λ) i = 0 prog = ProgressMeter.Progress( - n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) + n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) + stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) # apply update rule Δλ = DiffResults.gradient(grad_buf) From 2a4514e4ff0ab0459b7ed78dcdee2f61be61c691 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 9 Jun 2023 01:18:02 +0100 Subject: [PATCH 013/206] migrate to AbstractDifferentiation --- Project.toml | 3 +- src/AdvancedVI.jl | 101 ++---------------------------------- src/objectives/elbo/elbo.jl | 10 ++-- src/vi.jl | 8 ++- 4 files changed, 13 insertions(+), 109 deletions(-) diff --git a/Project.toml b/Project.toml index e73037ec..6964c135 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.3" [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -15,7 +16,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] Bijectors = "0.11, 0.12" @@ -27,7 +27,6 @@ ProgressMeter = "1.0.0" Requires = "0.5, 1.0" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -Tracker = "0.2.3" julia = "1.6" [extras] diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index dfb22930..809d86c6 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -9,14 +9,16 @@ using ProgressMeter, LinearAlgebra using LogDensityProblems -using ForwardDiff -using Tracker - using Distributions using DistributionsAD using StatsFuns +using ForwardDiff +import AbstractDifferentiation as AD + +value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...) + const PROGRESS = Ref(true) function turnprogress(switch::Bool) @info("[AdvancedVI]: global PROGRESS is set as $switch") @@ -35,58 +37,6 @@ function __init__() Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ZygoteAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - y, back = Zygote.pullback(f, λ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ReverseDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - tp = AdvancedVI.tape(f, λ) - ReverseDiff.gradient!(out, tp, λ) - return out - end - end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("compat/enzyme.jl") - export EnzymeAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:EnzymeAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(λ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) - return out - end - end end export @@ -97,16 +47,6 @@ export const VariationalPosterior = Distribution{Multivariate, Continuous} - -""" - grad!(f, λ, out) - -Computes the gradients of the objective f. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. -""" -function grad! end - """ vi(model, alg::VariationalInference) vi(model, alg::VariationalInference, q::VariationalPosterior) @@ -126,37 +66,6 @@ function vi end function update end -# default implementations -function grad!( - f::Function, - adtype::Type{<:ForwardDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - # Set chunk size and do ForwardMode. - chunk_size = getchunksize(adtype) - config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, λ) - else - ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) - end - ForwardDiff.gradient!(out, f, λ, config) -end - -function grad!( - f::Function, - ::Type{<:TrackerAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - λ_tracked = Tracker.param(λ) - y = f(λ_tracked) - Tracker.back!(y, 1.0) - - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(λ_tracked)) -end - # estimators abstract type AbstractVariationalObjective end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 2954ae8e..213cc725 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -22,15 +22,14 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) end function estimate_gradient!( + adbackend::AD.AbstractBackend, rng::Random.AbstractRNG, objective::ELBO, λ::Vector{<:Real}, - rebuild, - out::DiffResults.MutableDiffResult) + rebuild) n_samples = objective.n_samples - - grad!(ADBackend(), λ, out) do λ′ + nelbo, grad = value_and_gradient(λ; adbackend) do λ′ q_η = rebuild(λ′) ηs = rand(rng, q_η, n_samples) @@ -39,6 +38,5 @@ function estimate_gradient!( elbo = 𝔼ℓ + ℍ -elbo end - nelbo = DiffResults.value(out) - (elbo=-nelbo,) + first(grad), (elbo=-nelbo,) end diff --git a/src/vi.jl b/src/vi.jl index 4bf4595f..7b7858b8 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -15,7 +15,8 @@ function optimize( n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), - rng = Random.GLOBAL_RNG + rng = Random.default_rng(), + adbackend = AD.ForwardDiffBackend() ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -24,8 +25,6 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end - grad_buf = DiffResults.GradientResult(λ) - i = 0 prog = ProgressMeter.Progress( n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) @@ -33,10 +32,9 @@ function optimize( # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + Δλ, stats = estimate_gradient!(adbackend, rng, objective, λ, rebuild) # apply update rule - Δλ = DiffResults.gradient(grad_buf) Δλ = apply!(optimizer, λ, Δλ) @. λ = λ - Δλ From 93a16d8bc6aac9725081ea4c414ffd9343e6e79e Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 00:42:36 +0100 Subject: [PATCH 014/206] add location scale pre-packaged variational family, add functors --- Project.toml | 2 ++ src/AdvancedVI.jl | 19 +++++++++++++---- src/distributions/location_scale.jl | 33 +++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 src/distributions/location_scale.jl diff --git a/Project.toml b/Project.toml index 6964c135..88342f19 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" @@ -16,6 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] Bijectors = "0.11, 0.12" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 809d86c6..8c33f74a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -2,6 +2,8 @@ module AdvancedVI using Random: Random +using Functors + using Distributions, DistributionsAD, Bijectors using DocStringExtensions @@ -13,8 +15,9 @@ using Distributions using DistributionsAD using StatsFuns +import StatsBase: entropy -using ForwardDiff +using ForwardDiff, Tracker import AbstractDifferentiation as AD value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...) @@ -40,13 +43,18 @@ function __init__() end export - vi, + optimize, + ELBO, ADVI, + ADVIEnergy, + ClosedFormEntropy, + MonteCarloEntropy, + LocationScale, + FullRankGaussian, + MeanFieldGaussian, TruncatedADAGrad, DecayedADAGrad -const VariationalPosterior = Distribution{Multivariate, Continuous} - """ vi(model, alg::VariationalInference) vi(model, alg::VariationalInference, q::VariationalPosterior) @@ -73,6 +81,9 @@ include("objectives/elbo/elbo.jl") include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") +# Variational Families +include("distributions/location_scale.jl") + # optimisers include("optimisers.jl") diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl new file mode 100644 index 00000000..3aba53c5 --- /dev/null +++ b/src/distributions/location_scale.jl @@ -0,0 +1,33 @@ + +LocationScale(μ::LinearAlgebra.AbstractVector, + L::Union{<: LinearAlgebra.AbstractTriangular, + <: LinearAlgebra.Diagonal}, + q₀::Distributions.ContinuousMultivariateDistribution) = + transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) + +function location_scale_entropy( + q₀::Distributions.ContinuousMultivariateDistribution, + locscale_bijector::Bijectors.ComposedFunction) +end + +function entropy(q_trans::MultivariateTransformed{ + <: Distributions.ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}) + q_base = q_trans.dist + scale = q_trans.transform.inner.a + entropy(q_base) + first(logabsdet(scale)) +end + +function FullRankGaussian(μ::AbstractVector, + L::LinearAlgebra.AbstractTriangular) + q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ))) + LocationScale(μ, L, q₀) +end + +function MeanFieldGaussian(μ::AbstractVector, + L::LinearAlgebra.Diagonal) + q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ))) + LocationScale(μ, L, q₀) +end From 2b6e9ebed556dd67bb9325a5b04228637e1e03df Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:04:19 +0100 Subject: [PATCH 015/206] Revert "migrate to AbstractDifferentiation" This reverts commit 2a4514e4ff0ab0459b7ed78dcdee2f61be61c691. --- Project.toml | 2 +- src/AdvancedVI.jl | 101 ++++++++++++++++++++++++++++++++++-- src/objectives/elbo/elbo.jl | 10 ++-- src/vi.jl | 8 +-- 4 files changed, 108 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 88342f19..9a3303f5 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.3" [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -29,6 +28,7 @@ ProgressMeter = "1.0.0" Requires = "0.5, 1.0" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" +Tracker = "0.2.3" julia = "1.6" [extras] diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 8c33f74a..116bb63c 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -11,17 +11,15 @@ using ProgressMeter, LinearAlgebra using LogDensityProblems +using ForwardDiff +using Tracker + using Distributions using DistributionsAD using StatsFuns import StatsBase: entropy -using ForwardDiff, Tracker -import AbstractDifferentiation as AD - -value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...) - const PROGRESS = Ref(true) function turnprogress(switch::Bool) @info("[AdvancedVI]: global PROGRESS is set as $switch") @@ -40,6 +38,58 @@ function __init__() Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("compat/zygote.jl") + export ZygoteAD + + function AdvancedVI.grad!( + f::Function, + ::Type{<:ZygoteAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + y, back = Zygote.pullback(f, λ) + dy = first(back(1.0)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, dy) + return out + end + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("compat/reversediff.jl") + export ReverseDiffAD + + function AdvancedVI.grad!( + f::Function, + ::Type{<:ReverseDiffAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + tp = AdvancedVI.tape(f, λ) + ReverseDiff.gradient!(out, tp, λ) + return out + end + end + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("compat/enzyme.jl") + export EnzymeAD + + function AdvancedVI.grad!( + f::Function, + ::Type{<:EnzymeAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + # Use `Enzyme.ReverseWithPrimal` once it is released: + # https://github.com/EnzymeAD/Enzyme.jl/pull/598 + y = f(λ) + DiffResults.value!(out, y) + dy = DiffResults.gradient(out) + fill!(dy, 0) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) + return out + end + end end export @@ -55,6 +105,16 @@ export TruncatedADAGrad, DecayedADAGrad + +""" + grad!(f, λ, out) + +Computes the gradients of the objective f. Default implementation is provided for +`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. +This implicitly also gives a default implementation of `optimize!`. +""" +function grad! end + """ vi(model, alg::VariationalInference) vi(model, alg::VariationalInference, q::VariationalPosterior) @@ -74,6 +134,37 @@ function vi end function update end +# default implementations +function grad!( + f::Function, + adtype::Type{<:ForwardDiffAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) + # Set chunk size and do ForwardMode. + chunk_size = getchunksize(adtype) + config = if chunk_size == 0 + ForwardDiff.GradientConfig(f, λ) + else + ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) + end + ForwardDiff.gradient!(out, f, λ, config) +end + +function grad!( + f::Function, + ::Type{<:TrackerAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) + λ_tracked = Tracker.param(λ) + y = f(λ_tracked) + Tracker.back!(y, 1.0) + + DiffResults.value!(out, Tracker.data(y)) + DiffResults.gradient!(out, Tracker.grad(λ_tracked)) +end + # estimators abstract type AbstractVariationalObjective end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 213cc725..2954ae8e 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -22,14 +22,15 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) end function estimate_gradient!( - adbackend::AD.AbstractBackend, rng::Random.AbstractRNG, objective::ELBO, λ::Vector{<:Real}, - rebuild) + rebuild, + out::DiffResults.MutableDiffResult) n_samples = objective.n_samples - nelbo, grad = value_and_gradient(λ; adbackend) do λ′ + + grad!(ADBackend(), λ, out) do λ′ q_η = rebuild(λ′) ηs = rand(rng, q_η, n_samples) @@ -38,5 +39,6 @@ function estimate_gradient!( elbo = 𝔼ℓ + ℍ -elbo end - first(grad), (elbo=-nelbo,) + nelbo = DiffResults.value(out) + (elbo=-nelbo,) end diff --git a/src/vi.jl b/src/vi.jl index 7b7858b8..4bf4595f 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -15,8 +15,7 @@ function optimize( n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), - rng = Random.default_rng(), - adbackend = AD.ForwardDiffBackend() + rng = Random.GLOBAL_RNG ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -25,6 +24,8 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end + grad_buf = DiffResults.GradientResult(λ) + i = 0 prog = ProgressMeter.Progress( n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) @@ -32,9 +33,10 @@ function optimize( # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - Δλ, stats = estimate_gradient!(adbackend, rng, objective, λ, rebuild) + stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) # apply update rule + Δλ = DiffResults.gradient(grad_buf) Δλ = apply!(optimizer, λ, Δλ) @. λ = λ - Δλ From 1bfec36961c437cf000234bd29504fd49848d676 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:41:25 +0100 Subject: [PATCH 016/206] fix use optimized MvNormal specialization, add logpdf for Loc.Scale. --- Project.toml | 2 + src/AdvancedVI.jl | 23 +++++++----- src/distributions/location_scale.jl | 57 +++++++++++++++++++---------- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/Project.toml b/Project.toml index 9a3303f5..38a5026a 100644 --- a/Project.toml +++ b/Project.toml @@ -7,10 +7,12 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 116bb63c..d5a06fce 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -4,20 +4,23 @@ using Random: Random using Functors -using Distributions, DistributionsAD, Bijectors using DocStringExtensions -using ProgressMeter, LinearAlgebra +using ProgressMeter +using LinearAlgebra +using LinearAlgebra: AbstractTriangular using LogDensityProblems using ForwardDiff using Tracker -using Distributions -using DistributionsAD +using FillArrays +using PDMats +using Distributions, DistributionsAD +using Distributions: ContinuousMultivariateDistribution +using Bijectors -using StatsFuns import StatsBase: entropy const PROGRESS = Ref(true) @@ -29,7 +32,6 @@ end const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) include("ad.jl") -include("utils.jl") using Requires function __init__() @@ -116,9 +118,9 @@ This implicitly also gives a default implementation of `optimize!`. function grad! end """ - vi(model, alg::VariationalInference) - vi(model, alg::VariationalInference, q::VariationalPosterior) - vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray) + optimize(model, alg::VariationalInference) + optimize(model, alg::VariationalInference, q::VariationalPosterior) + optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) Constructs the variational posterior from the `model` and performs the optimization following the configuration of the given `VariationalInference` instance. @@ -130,7 +132,7 @@ following the configuration of the given `VariationalInference` instance. - `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` - `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior """ -function vi end +function optimize end function update end @@ -178,6 +180,7 @@ include("distributions/location_scale.jl") # optimisers include("optimisers.jl") +include("utils.jl") include("vi.jl") end # module diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 3aba53c5..365ae15e 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,33 +1,52 @@ -LocationScale(μ::LinearAlgebra.AbstractVector, - L::Union{<: LinearAlgebra.AbstractTriangular, - <: LinearAlgebra.Diagonal}, - q₀::Distributions.ContinuousMultivariateDistribution) = - transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) +function LocationScale(μ::AbstractVector, + L::Union{<: AbstractTriangular, + <: Diagonal}, + q₀::ContinuousMultivariateDistribution) + @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) + transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) +end function location_scale_entropy( - q₀::Distributions.ContinuousMultivariateDistribution, + q₀::ContinuousMultivariateDistribution, locscale_bijector::Bijectors.ComposedFunction) end -function entropy(q_trans::MultivariateTransformed{ - <: Distributions.ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}) +function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}) q_base = q_trans.dist scale = q_trans.transform.inner.a entropy(q_base) + first(logabsdet(scale)) end -function FullRankGaussian(μ::AbstractVector, - L::LinearAlgebra.AbstractTriangular) - q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ))) - LocationScale(μ, L, q₀) +function logpdf(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}, + z::AbstractVector) + q_base = q_trans.dist + reparam = q_trans.transform + scale = q_trans.transform.inner.a + η = inverse(reparam)(z) + logpdf(q_base, η) - first(logabsdet(scale)) +end + +function FullRankGaussian(μ::AbstractVector{T}, + L::AbstractTriangular{T,S}) where {T <: Real, S} + @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) + n_dims = length(μ) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), + PDMats.ScalMat{T}(n_dims, one(T))) + LocationScale(μ, L, q_base) end -function MeanFieldGaussian(μ::AbstractVector, - L::LinearAlgebra.Diagonal) - q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ))) - LocationScale(μ, L, q₀) +function MeanFieldGaussian(μ::AbstractVector{T}, + L::Diagonal{T,V}) where {T <: Real, V} + @assert (length(μ) == size(L,1)) + n_dims = length(μ) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), + PDMats.ScalMat{T}(n_dims, one(T))) + LocationScale(μ, L, q_base) end From 1003606283efd6b8cf340e74dced65d8ea72b296 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:52:53 +0100 Subject: [PATCH 017/206] remove dead code --- src/distributions/location_scale.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 365ae15e..1f7bad85 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -7,11 +7,6 @@ function LocationScale(μ::AbstractVector, transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) end -function location_scale_entropy( - q₀::ContinuousMultivariateDistribution, - locscale_bijector::Bijectors.ComposedFunction) -end - function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, <: Bijectors.ComposedFunction{ <: Bijectors.Shift, From 60a9987ed259b906da9cdd6e38ed33102497f389 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:56:30 +0100 Subject: [PATCH 018/206] fix location-scale logpdf - Full Monte Carlo ELBO estimation now works. I checked. --- src/AdvancedVI.jl | 3 ++- src/distributions/location_scale.jl | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d5a06fce..9b9d3ab2 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -21,7 +21,8 @@ using Distributions, DistributionsAD using Distributions: ContinuousMultivariateDistribution using Bijectors -import StatsBase: entropy +using StatsBase +using StatsBase: entropy const PROGRESS = Ref(true) function turnprogress(switch::Bool) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 1f7bad85..dd9b5f2a 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -7,20 +7,22 @@ function LocationScale(μ::AbstractVector, transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) end -function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}) +function StatsBase.entropy( + q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}) q_base = q_trans.dist scale = q_trans.transform.inner.a entropy(q_base) + first(logabsdet(scale)) end -function logpdf(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}, - z::AbstractVector) +function Distributions.logpdf( + q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}, + z::AbstractVector) q_base = q_trans.dist reparam = q_trans.transform scale = q_trans.transform.inner.a From cd84f02898d7cf82c530f98a91579f0b01935f33 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 22:21:22 +0100 Subject: [PATCH 019/206] add sticking-the-landing (STL) estimator --- src/objectives/elbo/elbo.jl | 36 ++++++++++++++++++++++++---------- src/objectives/elbo/entropy.jl | 35 ++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 2954ae8e..343581d8 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -21,23 +21,39 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) end +function (elbo::ELBO)(q_η::ContinuousMultivariateDistribution; + rng = Random.default_rng(), + n_samples::Int = elbo.n_samples, + q_η_entropy::ContinuousMultivariateDistribution = q_η) + ηs = rand(rng, q_η, n_samples) + 𝔼ℓ = elbo.energy_estimator(q_η, ηs) + ℍ = elbo.entropy_estimator(q_η_entropy, ηs) + 𝔼ℓ + ℍ +end + function estimate_gradient!( rng::Random.AbstractRNG, - objective::ELBO, + elbo::ELBO{EnergyEst, EntropyEst}, λ::Vector{<:Real}, rebuild, - out::DiffResults.MutableDiffResult) - - n_samples = objective.n_samples + out::DiffResults.MutableDiffResult) where {EnergyEst <: AbstractEnergyEstimator, + EntropyEst <: AbstractEntropyEstimator} + + # Gradient-stopping for computing the sticking-the-landing control variate + q_η_stop = if EntropyEst isa MonteCarloEntropy{true} + rebuild(λ) + else + nothing + end grad!(ADBackend(), λ, out) do λ′ q_η = rebuild(λ′) - ηs = rand(rng, q_η, n_samples) - - 𝔼ℓ = objective.energy_estimator(q_η, ηs) - ℍ = objective.entropy_estimator(q_η, ηs) - elbo = 𝔼ℓ + ℍ - -elbo + q_η_entropy = if EntropyEst isa MonteCarloEntropy{true} + q_η_stop + else + q_η + end + -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) end nelbo = DiffResults.value(out) (elbo=-nelbo,) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index d7fb7054..8efb7c71 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -1,13 +1,38 @@ -struct ClosedFormEntropy <: AbstractEntropyEstimator -end +struct ClosedFormEntropy <: AbstractEntropyEstimator end -function (::ClosedFormEntropy)(q, ηs::AbstractMatrix) +function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end -struct MonteCarloEntropy <: AbstractEntropyEstimator -end +struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end + +MonteCarloEntropy() = MonteCarloEntropy{false}() + +""" + Sticking the Landing Control Variate + + # Explanation + + This eatimator forms a control variate of the form of + + c(z) = 𝔼-logq(z) + logq(z) = ℍ[q] - logq(z) + + Adding this to the closed-form entropy ELBO estimator yields: + + ELBO - c(z) = 𝔼logπ(z) + ℍ[q] - c(z) = 𝔼logπ(z) - logq(z), + + which has the same expectation, but lower variance when π ≈ q, + and higher variance when π ≉ q. + + # Reference + + Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud. + "Sticking the landing: Simple, lower-variance gradient estimators for + variational inference." + Advances in Neural Information Processing Systems 30 (2017). +""" +StickingTheLandingEntropy() = MonteCarloEntropy{true}() function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) From 768641b1979f4e63125780e53f48e21794bbcdd2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 22:41:50 +0100 Subject: [PATCH 020/206] migrate to Optimisers.jl --- Project.toml | 1 + src/AdvancedVI.jl | 11 +++-------- src/vi.jl | 27 ++++++++++++++++----------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 38a5026a..ba807698 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9b9d3ab2..5a02501b 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -4,6 +4,8 @@ using Random: Random using Functors +using Optimisers + using DocStringExtensions using ProgressMeter @@ -12,8 +14,7 @@ using LinearAlgebra: AbstractTriangular using LogDensityProblems -using ForwardDiff -using Tracker +using ForwardDiff, Tracker using FillArrays using PDMats @@ -24,12 +25,6 @@ using Bijectors using StatsBase using StatsBase: entropy -const PROGRESS = Ref(true) -function turnprogress(switch::Bool) - @info("[AdvancedVI]: global PROGRESS is set as $switch") - PROGRESS[] = switch -end - const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) include("ad.jl") diff --git a/src/vi.jl b/src/vi.jl index 4bf4595f..6c8b26d1 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -10,12 +10,13 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer the steps. """ function optimize( - objective::AbstractVariationalObjective, - rebuild::Function, + objective ::AbstractVariationalObjective, + rebuild, n_max_iter::Int, - λ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad(), - rng = Random.GLOBAL_RNG + λ ::AbstractVector{<:Real}; + optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), + rng ::Random.AbstractRNG = Random.GLOBAL_RNG, + progress ::Bool = true ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -24,21 +25,25 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end + optstate = Optimisers.init(optimizer, λ) grad_buf = DiffResults.GradientResult(λ) i = 0 prog = ProgressMeter.Progress( - n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) + n_max_iter; + desc = "[$(string(objective))] Optimizing...", + barlen = 0, + enabled = progress, + showspeed = true) # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) - - # apply update rule - Δλ = DiffResults.gradient(grad_buf) - Δλ = apply!(optimizer, λ, Δλ) - @. λ = λ - Δλ + g = DiffResults.gradient(grad_buf) + + optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) + Optimisers.subtract!(λ, Δλ) stat′ = (Δλ=norm(Δλ),) stats = merge(stats, stat′) From ca02fa315486a0977327f3e2824cd87b40b1908a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 22:42:38 +0100 Subject: [PATCH 021/206] remove execution time measurement (replace later with somethin else) --- src/vi.jl | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/vi.jl b/src/vi.jl index 6c8b26d1..e5062def 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -36,23 +36,20 @@ function optimize( enabled = progress, showspeed = true) - # add criterion? A running mean maybe? - time_elapsed = @elapsed begin - for i = 1:n_max_iter - stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) - g = DiffResults.gradient(grad_buf) + for i = 1:n_max_iter + stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + g = DiffResults.gradient(grad_buf) - optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) - Optimisers.subtract!(λ, Δλ) + optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) + Optimisers.subtract!(λ, Δλ) - stat′ = (Δλ=norm(Δλ),) - stats = merge(stats, stat′) + stat′ = (Δλ=norm(Δλ),) + stats = merge(stats, stat′) - AdvancedVI.DEBUG && @debug "Step $i" stats... + AdvancedVI.DEBUG && @debug "Step $i" stats... pm_next!(prog, stats) - end end - return λ + λ end # function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG) From a48377f016c82461000ba10c35803a5181f4b4a9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 21:47:22 +0100 Subject: [PATCH 022/206] fix use multiple dispatch for deciding whether to stop entropy grad. --- src/objectives/elbo/elbo.jl | 21 +++++++-------------- src/objectives/elbo/entropy.jl | 4 ++++ src/vi.jl | 4 ++-- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 343581d8..cebd7d82 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -15,6 +15,8 @@ struct ELBO{EnergyEst <: AbstractEnergyEstimator, n_samples::Int end +skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator) + Base.string(::ELBO) = "ELBO" function ADVI(ℓπ, b⁻¹, n_samples::Int) @@ -33,28 +35,19 @@ end function estimate_gradient!( rng::Random.AbstractRNG, - elbo::ELBO{EnergyEst, EntropyEst}, + elbo::ELBO, λ::Vector{<:Real}, rebuild, - out::DiffResults.MutableDiffResult) where {EnergyEst <: AbstractEnergyEstimator, - EntropyEst <: AbstractEntropyEstimator} + out::DiffResults.MutableDiffResult) # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = if EntropyEst isa MonteCarloEntropy{true} - rebuild(λ) - else - nothing - end + q_η_stop = skip_entropy_gradient(elbo) ? rebuild(λ) : nothing grad!(ADBackend(), λ, out) do λ′ q_η = rebuild(λ′) - q_η_entropy = if EntropyEst isa MonteCarloEntropy{true} - q_η_stop - else - q_η - end + q_η_entropy = skip_entropy_gradient(elbo) ? q_η_stop : q_η -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) end nelbo = DiffResults.value(out) - (elbo=-nelbo,) + out, (elbo=-nelbo,) end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 8efb7c71..50f498d6 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -5,6 +5,8 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end +skip_entropy_gradient(::ClosedFormEntropy) = false + struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end MonteCarloEntropy() = MonteCarloEntropy{false}() @@ -34,6 +36,8 @@ MonteCarloEntropy() = MonteCarloEntropy{false}() """ StickingTheLandingEntropy() = MonteCarloEntropy{true}() +skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding + function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) mapreduce(+, eachcol(ηs)) do ηᵢ diff --git a/src/vi.jl b/src/vi.jl index e5062def..8b8fe14f 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -37,8 +37,8 @@ function optimize( showspeed = true) for i = 1:n_max_iter - stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) - g = DiffResults.gradient(grad_buf) + grad_buf, stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + g = DiffResults.gradient(grad_buf) optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) Optimisers.subtract!(λ, Δλ) From 0b40ccf6ef10e6ebef9d6372e407731bb4dc2ca0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 22:16:30 +0100 Subject: [PATCH 023/206] add termination decision, callback arguments --- src/vi.jl | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/vi.jl b/src/vi.jl index 8b8fe14f..1a4d57ec 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -16,7 +16,9 @@ function optimize( λ ::AbstractVector{<:Real}; optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), rng ::Random.AbstractRNG = Random.GLOBAL_RNG, - progress ::Bool = true + progress ::Bool = true, + callback! = nothing, + terminate = (args...) -> false, ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -28,6 +30,7 @@ function optimize( optstate = Optimisers.init(optimizer, λ) grad_buf = DiffResults.GradientResult(λ) + q = rebuild(λ) i = 0 prog = ProgressMeter.Progress( n_max_iter; @@ -43,11 +46,22 @@ function optimize( optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) Optimisers.subtract!(λ, Δλ) - stat′ = (Δλ=norm(Δλ),) + stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g)) stats = merge(stats, stat′) + q = rebuild(λ) + + if !isnothing(callback!) + stat′ = callback!(q, stats) + stats = !isnothing(stat′) ? merge(stat′, stats) : stats + end AdvancedVI.DEBUG && @debug "Step $i" stats... pm_next!(prog, stats) + + # Termination decision is work in progress + if terminate(rng, q, objective, stats) + break + end end λ end From 21db3fb842d226148ee23b758c0756e332132066 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 22:35:48 +0100 Subject: [PATCH 024/206] add Base.show to modules --- src/objectives/elbo/advi_energy.jl | 2 ++ src/objectives/elbo/elbo.jl | 6 +++++- src/objectives/elbo/entropy.jl | 4 ++++ src/vi.jl | 1 - 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl index b27b752e..078a157e 100644 --- a/src/objectives/elbo/advi_energy.jl +++ b/src/objectives/elbo/advi_energy.jl @@ -26,6 +26,8 @@ end ADVIEnergy(prob) = ADVIEnergy(prob, identity) +Base.show(io::IO, energy::ADVIEnergy) = print(io, "ADVIEnergy()") + function (energy::ADVIEnergy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) mapreduce(+, eachcol(ηs)) do ηᵢ diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index cebd7d82..b26516d9 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -17,7 +17,11 @@ end skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator) -Base.string(::ELBO) = "ELBO" +Base.show(io::IO, elbo::ELBO) = print( + io, + "ELBO(energy_estimator=$(elbo.energy_estimator), " * + "entropy_estimator=$(elbo.entropy_estimator)), " * + "n_samples=$(elbo.n_samples))") function ADVI(ℓπ, b⁻¹, n_samples::Int) ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 50f498d6..ddeb64a9 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -11,6 +11,8 @@ struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end MonteCarloEntropy() = MonteCarloEntropy{false}() +Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()") + """ Sticking the Landing Control Variate @@ -38,6 +40,8 @@ StickingTheLandingEntropy() = MonteCarloEntropy{true}() skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding +Base.show(io::IO, entropy::MonteCarloEntropy{true}) = print(io, "StickingTheLandingEntropy()") + function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) mapreduce(+, eachcol(ηs)) do ηᵢ diff --git a/src/vi.jl b/src/vi.jl index 1a4d57ec..605464b6 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -34,7 +34,6 @@ function optimize( i = 0 prog = ProgressMeter.Progress( n_max_iter; - desc = "[$(string(objective))] Optimizing...", barlen = 0, enabled = progress, showspeed = true) From 25c51b4796b2e550d1ee9747e5ccbf81a48aff38 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 23:03:25 +0100 Subject: [PATCH 025/206] add interface calling `restructure`, rename rebuild -> restructure --- src/objectives/elbo/elbo.jl | 6 ++-- src/vi.jl | 61 ++++++++++++++++++------------------- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index b26516d9..b3bad3c0 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -41,14 +41,14 @@ function estimate_gradient!( rng::Random.AbstractRNG, elbo::ELBO, λ::Vector{<:Real}, - rebuild, + restructure, out::DiffResults.MutableDiffResult) # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(elbo) ? rebuild(λ) : nothing + q_η_stop = skip_entropy_gradient(elbo) ? restructure(λ) : nothing grad!(ADBackend(), λ, out) do λ′ - q_η = rebuild(λ′) + q_η = restructure(λ′) q_η_entropy = skip_entropy_gradient(elbo) ? q_η_stop : q_η -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) end diff --git a/src/vi.jl b/src/vi.jl index 605464b6..f1f4bc25 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -11,9 +11,9 @@ the steps. """ function optimize( objective ::AbstractVariationalObjective, - rebuild, - n_max_iter::Int, - λ ::AbstractVector{<:Real}; + restructure, + λ ::AbstractVector{<:Real}, + n_max_iter::Int; optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), rng ::Random.AbstractRNG = Random.GLOBAL_RNG, progress ::Bool = true, @@ -30,50 +30,47 @@ function optimize( optstate = Optimisers.init(optimizer, λ) grad_buf = DiffResults.GradientResult(λ) - q = rebuild(λ) - i = 0 - prog = ProgressMeter.Progress( - n_max_iter; - barlen = 0, - enabled = progress, - showspeed = true) + prog = ProgressMeter.Progress(n_max_iter; + barlen = 0, + enabled = progress, + showspeed = true) + stats = Vector{NamedTuple}(undef, n_max_iter) - for i = 1:n_max_iter - grad_buf, stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + for t = 1:n_max_iter + grad_buf, stat = estimate_gradient!(rng, objective, λ, restructure, grad_buf) g = DiffResults.gradient(grad_buf) optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) Optimisers.subtract!(λ, Δλ) stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g)) - stats = merge(stats, stat′) - q = rebuild(λ) + stat = merge(stat, stat′) + q = restructure(λ) if !isnothing(callback!) - stat′ = callback!(q, stats) - stats = !isnothing(stat′) ? merge(stat′, stats) : stats + stat′ = callback!(q, stat) + stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - AdvancedVI.DEBUG && @debug "Step $i" stats... - pm_next!(prog, stats) + AdvancedVI.DEBUG && @debug "Step $i" stat... + + pm_next!(prog, stat) + stats[t] = stat # Termination decision is work in progress - if terminate(rng, q, objective, stats) + if terminate(rng, q, objective, stat) + stats = stats[1:t] break end end - λ + λ, stats end -# function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG) -# θ = copy(θ_init) -# optimize!(grad_estimator, rebuild, n_max_iter, λ, optimizer = optimizer, rng = rng) - -# # If `q` is a mean-field approx we use the specialized `update` function -# if q isa Distribution -# return update(q, θ) -# else -# # Otherwise we assume it's a mapping θ → q -# return q(θ) -# end -# end +function optimize(objective::AbstractVariationalObjective, + q, + n_max_iter::Int; + kwargs...) + λ, restructure = Optimisers.destructure(q) + λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...) + restructure(λ), stats +end From fc200462e0a6929ca580d6cabaad27afd179b30f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:33:47 +0100 Subject: [PATCH 026/206] add estimator state interface, add control variate interface to ADVI --- src/AdvancedVI.jl | 12 ++++++- src/objectives/elbo/advi.jl | 64 +++++++++++++++++++++++++++++++++++++ src/objectives/elbo/elbo.jl | 57 --------------------------------- src/vi.jl | 22 +++++++------ 4 files changed, 88 insertions(+), 67 deletions(-) create mode 100644 src/objectives/elbo/advi.jl delete mode 100644 src/objectives/elbo/elbo.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5a02501b..f2eb2317 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -166,7 +166,17 @@ end # estimators abstract type AbstractVariationalObjective end -include("objectives/elbo/elbo.jl") +function estimate_gradient end + +abstract type AbstractEnergyEstimator end +abstract type AbstractEntropyEstimator end +abstract type AbstractControlVariate end + +init(::Nothing) = nothing + +update(::Nothing, ::Nothing) = (nothing, nothing) + +include("objectives/elbo/advi.jl") include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl new file mode 100644 index 00000000..66e5f320 --- /dev/null +++ b/src/objectives/elbo/advi.jl @@ -0,0 +1,64 @@ + +struct ADVI{EnergyEst <: AbstractEnergyEstimator, + EntropyEst <: AbstractEntropyEstimator, + ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective + energy_estimator::EnergyEst + entropy_estimator::EntropyEst + control_variate::ControlVar + n_samples::Int +end + +skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) + +init(advi::ADVI) = init(advi.control_variate) + +Base.show(io::IO, advi::ADVI) = print( + io, + "ADVI(energy_estimator=$(advi.energy_estimator), " * + "entropy_estimator=$(advi.entropy_estimator)), " * + "n_samples=$(advi.n_samples))") + +function ADVI(energy_estimator::AbstractEnergyEstimator, + entropy_estimator::AbstractEntropyEstimator, + n_samples::Int) + ADVI(energy_estimator, entropy_estimator, nothing, n_samples) +end + +function ADVI(ℓπ, b⁻¹, n_samples::Int) + ADVI(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) +end + +function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; + rng ::Random.AbstractRNG = Random.default_rng(), + n_samples ::Int = advi.n_samples, + ηs ::AbstractMatrix = rand(rng, q_η, n_samples), + q_η_entropy::ContinuousMultivariateDistribution = q_η) + 𝔼ℓ = advi.energy_estimator(q_η, ηs) + ℍ = advi.entropy_estimator(q_η_entropy, ηs) + 𝔼ℓ + ℍ +end + +function estimate_gradient( + rng::Random.AbstractRNG, + advi::ADVI, + est_state, + λ::Vector{<:Real}, + restructure, + out::DiffResults.MutableDiffResult) + + # Gradient-stopping for computing the sticking-the-landing control variate + q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing + + grad!(ADBackend(), λ, out) do λ′ + q_η = restructure(λ′) + q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η + -advi(q_η; rng, q_η_entropy) + end + nelbo = DiffResults.value(out) + stat = (elbo=-nelbo,) + + est_state, stat′ = update(advi.control_variate, est_state) + stat = !isnothing(stat′) ? merge(stat′, stat) : stat + + out, est_state, stat +end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl deleted file mode 100644 index b3bad3c0..00000000 --- a/src/objectives/elbo/elbo.jl +++ /dev/null @@ -1,57 +0,0 @@ - -abstract type AbstractEnergyEstimator end -abstract type AbstractEntropyEstimator end - -struct ELBO{EnergyEst <: AbstractEnergyEstimator, - EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective - # Evidence Lower Bound - # - # Jordan, Michael I., et al. - # "An introduction to variational methods for graphical models." - # Machine learning 37 (1999): 183-233. - - energy_estimator::EnergyEst - entropy_estimator::EntropyEst - n_samples::Int -end - -skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator) - -Base.show(io::IO, elbo::ELBO) = print( - io, - "ELBO(energy_estimator=$(elbo.energy_estimator), " * - "entropy_estimator=$(elbo.entropy_estimator)), " * - "n_samples=$(elbo.n_samples))") - -function ADVI(ℓπ, b⁻¹, n_samples::Int) - ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) -end - -function (elbo::ELBO)(q_η::ContinuousMultivariateDistribution; - rng = Random.default_rng(), - n_samples::Int = elbo.n_samples, - q_η_entropy::ContinuousMultivariateDistribution = q_η) - ηs = rand(rng, q_η, n_samples) - 𝔼ℓ = elbo.energy_estimator(q_η, ηs) - ℍ = elbo.entropy_estimator(q_η_entropy, ηs) - 𝔼ℓ + ℍ -end - -function estimate_gradient!( - rng::Random.AbstractRNG, - elbo::ELBO, - λ::Vector{<:Real}, - restructure, - out::DiffResults.MutableDiffResult) - - # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(elbo) ? restructure(λ) : nothing - - grad!(ADBackend(), λ, out) do λ′ - q_η = restructure(λ′) - q_η_entropy = skip_entropy_gradient(elbo) ? q_η_stop : q_η - -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) - end - nelbo = DiffResults.value(out) - out, (elbo=-nelbo,) -end diff --git a/src/vi.jl b/src/vi.jl index f1f4bc25..ebb246be 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -27,8 +27,9 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end - optstate = Optimisers.init(optimizer, λ) - grad_buf = DiffResults.GradientResult(λ) + opt_state = Optimisers.init(optimizer, λ) + est_state = init(objective) + grad_buf = DiffResults.GradientResult(λ) prog = ProgressMeter.Progress(n_max_iter; barlen = 0, @@ -37,22 +38,25 @@ function optimize( stats = Vector{NamedTuple}(undef, n_max_iter) for t = 1:n_max_iter - grad_buf, stat = estimate_gradient!(rng, objective, λ, restructure, grad_buf) - g = DiffResults.gradient(grad_buf) + stat = (iteration=t,) - optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) + grad_buf, est_state, stat′ = estimate_gradient(rng, objective, est_state, λ, restructure, grad_buf) + g = DiffResults.gradient(grad_buf) + stat = merge(stat, stat′) + + opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g) Optimisers.subtract!(λ, Δλ) + stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) + stat = merge(stat, stat′) - stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g)) - stat = merge(stat, stat′) - q = restructure(λ) + q = restructure(λ) if !isnothing(callback!) stat′ = callback!(q, stat) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - AdvancedVI.DEBUG && @debug "Step $i" stat... + AdvancedVI.DEBUG && @debug "Step $t" stat... pm_next!(prog, stat) stats[t] = stat From 6faa807f067ff77856c307ef4baa11865616deae Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:39:05 +0100 Subject: [PATCH 027/206] fix `show(advi)` to show control variate --- src/objectives/elbo/advi.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 66e5f320..de2c683b 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -15,7 +15,8 @@ init(advi::ADVI) = init(advi.control_variate) Base.show(io::IO, advi::ADVI) = print( io, "ADVI(energy_estimator=$(advi.energy_estimator), " * - "entropy_estimator=$(advi.entropy_estimator)), " * + "entropy_estimator=$(advi.entropy_estimator), " * + (!isnothing(advi.control_variate) ? "control_variate=$(advi.control_variate), " : "") * "n_samples=$(advi.n_samples))") function ADVI(energy_estimator::AbstractEnergyEstimator, From 7095d276f5947b855289099a0ce56f2106c8b16c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:39:45 +0100 Subject: [PATCH 028/206] fix simplify `show(advi.control_variate)` --- src/objectives/elbo/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index de2c683b..dc2962ee 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -16,7 +16,7 @@ Base.show(io::IO, advi::ADVI) = print( io, "ADVI(energy_estimator=$(advi.energy_estimator), " * "entropy_estimator=$(advi.entropy_estimator), " * - (!isnothing(advi.control_variate) ? "control_variate=$(advi.control_variate), " : "") * + "control_variate=$(advi.control_variate), " * "n_samples=$(advi.n_samples))") function ADVI(energy_estimator::AbstractEnergyEstimator, From 9169ae262f8ac289d8e7355f8642584e18da3614 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:51:48 +0100 Subject: [PATCH 029/206] fix type piracy by wrapping location-scale bijected distribution --- src/distributions/location_scale.jl | 67 ++++++++++++++++------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index dd9b5f2a..f3c95d0c 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,41 +1,51 @@ -function LocationScale(μ::AbstractVector, - L::Union{<: AbstractTriangular, - <: Diagonal}, - q₀::ContinuousMultivariateDistribution) - @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) +import Base: rand, _rand! + +struct LocationScale{ReparamMvDist <: Bijectors.TransformedDistribution} <: ContinuousMultivariateDistribution + q_trans::ReparamMvDist + + function LocationScale(μ::AbstractVector, + L::Union{<: AbstractTriangular, + <: Diagonal}, + q₀::ContinuousMultivariateDistribution) + @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) + q_trans = transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) + new{typeof(q_trans)}(q_trans) + end + + function LocationScale(q_trans::Bijectors.TransformedDistribution) + new{typeof(q_trans)}(q_trans) + end end -function StatsBase.entropy( - q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}) - q_base = q_trans.dist - scale = q_trans.transform.inner.a +Functors.@functor LocationScale + +Base.length(q::LocationScale) = length(q.q_trans) +Base.size(q::LocationScale) = size(q.q_trans) + +function StatsBase.entropy(q::LocationScale) + q_base = q.q_trans.dist + scale = q.q_trans.transform.inner.a entropy(q_base) + first(logabsdet(scale)) end -function Distributions.logpdf( - q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}, - z::AbstractVector) - q_base = q_trans.dist - reparam = q_trans.transform - scale = q_trans.transform.inner.a - η = inverse(reparam)(z) - logpdf(q_base, η) - first(logabsdet(scale)) -end + +Distributions.logpdf(q::LocationScale, z::AbstractVector) = logpdf(q.q_trans, z) + +_logpdf(q::LocationScale, y::AbstractVector) = _logpdf(q.q_trans, y) + +rand(q::LocationScale) = rand(q.q_trans) + +rand(rng::Random.AbstractRNG, q::LocationScale, num_samples::Int) = rand(rng, q.q_trans, num_samples) + +_rand!(rng::Random.AbstractRNG, q::LocationScale, x::AbstractVector{<:Real}) = _rand!(rng, q.q_trans, x) + function FullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T,S}) where {T <: Real, S} @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), - PDMats.ScalMat{T}(n_dims, one(T))) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) LocationScale(μ, L, q_base) end @@ -43,7 +53,6 @@ function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T,V}) where {T <: Real, V} @assert (length(μ) == size(L,1)) n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), - PDMats.ScalMat{T}(n_dims, one(T))) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) LocationScale(μ, L, q_base) end From 3db73011a430fb3aa5830264be687d860410f483 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 26 Jun 2023 23:01:27 +0100 Subject: [PATCH 030/206] remove old AdvancedVI custom optimizers --- Project.toml | 1 + src/AdvancedVI.jl | 15 +++----- src/optimisers.jl | 94 ----------------------------------------------- src/vi.jl | 11 +----- 4 files changed, 8 insertions(+), 113 deletions(-) delete mode 100644 src/optimisers.jl diff --git a/Project.toml b/Project.toml index ba807698..d2708915 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] Bijectors = "0.11, 0.12" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index f2eb2317..76c6d859 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,9 +1,12 @@ + module AdvancedVI -using Random: Random +using UnPack -using Functors +import Random: AbstractRNG, default_rng +import Distributions: logpdf, _logpdf, rand, _rand!, _rand! +using Functors using Optimisers using DocStringExtensions @@ -31,11 +34,6 @@ include("ad.jl") using Requires function __init__() - @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin - apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) - Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) - Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) - end @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin include("compat/zygote.jl") export ZygoteAD @@ -183,9 +181,6 @@ include("objectives/elbo/entropy.jl") # Variational Families include("distributions/location_scale.jl") -# optimisers -include("optimisers.jl") - include("utils.jl") include("vi.jl") diff --git a/src/optimisers.jl b/src/optimisers.jl deleted file mode 100644 index 8077f98c..00000000 --- a/src/optimisers.jl +++ /dev/null @@ -1,94 +0,0 @@ -const ϵ = 1e-8 - -""" - TruncatedADAGrad(η=0.1, τ=1.0, n=100) - -Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - τ: constant scale factor - - n: number of previous gradient norms to use in the scaling. -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. - -[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E). -""" -mutable struct TruncatedADAGrad - eta::Float64 - tau::Float64 - n::Int - - iters::IdDict - acc::IdDict -end - -function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100) - TruncatedADAGrad(η, τ, n, IdDict(), IdDict()) -end - -function apply!(o::TruncatedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - τ = o.tau - - g² = get!( - o.acc, - x, - [zeros(T, size(x)) for j = 1:o.n] - )::Array{typeof(Tracker.data(Δ)), 1} - i = get!(o.iters, x, 1)::Int - - # Example: suppose i = 12 and o.n = 10 - idx = mod(i - 1, o.n) + 1 # => idx = 2 - - # set the current - @inbounds @. g²[idx] = Δ^2 # => g²[2] = Δ^2 where Δ is the (o.n + 2)-th Δ - - # TODO: make more efficient and stable - s = sum(g²) - - # increment - o.iters[x] += 1 - - # TODO: increment (but "truncate") - # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1 - - @. Δ *= η / (τ + sqrt(s) + ϵ) -end - -""" - DecayedADAGrad(η=0.1, pre=1.0, post=0.9) - -Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - pre: weight of new gradient norm - - post: weight of histroy of gradient norms -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. -""" -mutable struct DecayedADAGrad - eta::Float64 - pre::Float64 - post::Float64 - - acc::IdDict -end - -DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict()) - -function apply!(o::DecayedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x)) - @. acc = o.post * acc + o.pre * Δ^2 - @. Δ *= η / (√acc + ϵ) -end diff --git a/src/vi.jl b/src/vi.jl index ebb246be..842f187e 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -14,19 +14,12 @@ function optimize( restructure, λ ::AbstractVector{<:Real}, n_max_iter::Int; - optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), - rng ::Random.AbstractRNG = Random.GLOBAL_RNG, + optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), + rng ::AbstractRNG = default_rng(), progress ::Bool = true, callback! = nothing, terminate = (args...) -> false, ) - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) - end - opt_state = Optimisers.init(optimizer, λ) est_state = init(objective) grad_buf = DiffResults.GradientResult(λ) From e6a082aadbd3fa92e60fedf5373f2efbb1875ecc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 26 Jun 2023 23:47:04 +0100 Subject: [PATCH 031/206] fix Location Scale to not depend on Bijectors --- src/distributions/location_scale.jl | 101 +++++++++++++++++----------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index f3c95d0c..c46b5111 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,58 +1,79 @@ -import Base: rand, _rand! - -struct LocationScale{ReparamMvDist <: Bijectors.TransformedDistribution} <: ContinuousMultivariateDistribution - q_trans::ReparamMvDist - - function LocationScale(μ::AbstractVector, - L::Union{<: AbstractTriangular, - <: Diagonal}, - q₀::ContinuousMultivariateDistribution) +struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution + location::L + scale ::S + dist ::D + epsilon ::R + + function VILocationScale(μ::AbstractVector{<:Real}, + L::Union{<:AbstractTriangular{<:Real}, + <:Diagonal{<:Real}}, + q_base::ContinuousUnivariateDistribution, + epsilon::Real) + # Restricting all the arguments to have the same types creates problems + # with dual-variable-based AD frameworks. @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - q_trans = transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) - new{typeof(q_trans)}(q_trans) - end - - function LocationScale(q_trans::Bijectors.TransformedDistribution) - new{typeof(q_trans)}(q_trans) + new{typeof(μ), typeof(L), typeof(q_base), typeof(epsilon)}(μ, L, q_base, epsilon) end end -Functors.@functor LocationScale +Functors.@functor VILocationScale (location, scale) -Base.length(q::LocationScale) = length(q.q_trans) -Base.size(q::LocationScale) = size(q.q_trans) +Base.length(q::VILocationScale) = length(q.location) +Base.size(q::VILocationScale) = size(q.location) -function StatsBase.entropy(q::LocationScale) - q_base = q.q_trans.dist - scale = q.q_trans.transform.inner.a - entropy(q_base) + first(logabsdet(scale)) +function StatsBase.entropy(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + n_dims*entropy(dist) + first(logabsdet(scale)) end +function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) +end -Distributions.logpdf(q::LocationScale, z::AbstractVector) = logpdf(q.q_trans, z) - -_logpdf(q::LocationScale, y::AbstractVector) = _logpdf(q.q_trans, y) +function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) +end -rand(q::LocationScale) = rand(q.q_trans) +function rand(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(dist, n_dims) + location +end -rand(rng::Random.AbstractRNG, q::LocationScale, num_samples::Int) = rand(rng, q.q_trans, num_samples) +function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(dist, n_dims, num_samples) .+ location +end -_rand!(rng::Random.AbstractRNG, q::LocationScale, x::AbstractVector{<:Real}) = _rand!(rng, q.q_trans, x) +function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) + @unpack location, scale, dist = q + rand!(rng, dist, x) + x .= scale*x + return x += location +end +function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) + @unpack location, scale, dist = q + rand!(rng, dist, x) + x *= scale + return x += location +end -function FullRankGaussian(μ::AbstractVector{T}, - L::AbstractTriangular{T,S}) where {T <: Real, S} - @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) - LocationScale(μ, L, q_base) +function VIFullRankGaussian(μ::AbstractVector{T}, + L::AbstractTriangular{T}, + epsilon::Real = eps(T)) where {T <: Real} + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base, epsilon) end -function MeanFieldGaussian(μ::AbstractVector{T}, - L::Diagonal{T,V}) where {T <: Real, V} - @assert (length(μ) == size(L,1)) - n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) - LocationScale(μ, L, q_base) +function VIMeanFieldGaussian(μ::AbstractVector{T}, + L::Diagonal{T}, + epsilon::Real = eps(T)) where {T <: Real} + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base, epsilon) end From a034ebdec0e42d63211fe8e1c23d4b4e714a30bb Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 00:50:33 +0100 Subject: [PATCH 032/206] fix RNG namespace --- src/objectives/elbo/advi.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index dc2962ee..311a94f3 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -30,9 +30,9 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) end function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; - rng ::Random.AbstractRNG = Random.default_rng(), - n_samples ::Int = advi.n_samples, - ηs ::AbstractMatrix = rand(rng, q_η, n_samples), + rng ::AbstractRNG = default_rng(), + n_samples ::Int = advi.n_samples, + ηs ::AbstractMatrix = rand(rng, q_η, n_samples), q_η_entropy::ContinuousMultivariateDistribution = q_η) 𝔼ℓ = advi.energy_estimator(q_η, ηs) ℍ = advi.entropy_estimator(q_η_entropy, ηs) @@ -40,7 +40,7 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; end function estimate_gradient( - rng::Random.AbstractRNG, + rng::AbstractRNG, advi::ADVI, est_state, λ::Vector{<:Real}, From e19abd3d06291090f45b4b8b118e7be3003343c5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:08:46 +0100 Subject: [PATCH 033/206] fix location scale logpdf bug --- src/distributions/location_scale.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index c46b5111..c1803ffe 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -23,19 +23,19 @@ Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) function StatsBase.entropy(q::VILocationScale) - @unpack location, scale, dist = q + @unpack location, scale, dist = q n_dims = length(location) n_dims*entropy(dist) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) From 680c1864ecfe2a2867e9f48fe4bbf1ca37065aa3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:12:19 +0100 Subject: [PATCH 034/206] add Accessors dependency --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index d2708915..add1e391 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.3" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" From 4c6cabf688af0552a307c22b821901cc792676be Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:12:44 +0100 Subject: [PATCH 035/206] add location scale, autodiff tests --- test/ad.jl | 22 +++++++++++++++++++++ test/distributions.jl | 45 +++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 32 ++++++++---------------------- 3 files changed, 75 insertions(+), 24 deletions(-) create mode 100644 test/ad.jl create mode 100644 test/distributions.jl diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 00000000..c084165c --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,22 @@ + +using ReTest +using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote +using AdvancedVI: grad! + +@testset "ad" begin + @testset "$(string(adsymbol))" for adsymbol ∈ [ + :forwarddiff, :reversediff, :tracker, :enzyme, :zygote] + D = 10 + A = randn(D, D) + λ = randn(D) + AdvancedVI.setadbackend(adsymbol) + grad_buf = DiffResults.GradientResult(λ) + AdvancedVI.grad!(AdvancedVI.ADBackend(), λ, grad_buf) do λ′ + λ′'*A*λ′ / 2 + end + ∇ = DiffResults.gradient(grad_buf) + f = DiffResults.value(grad_buf) + @test ∇ ≈ (A + A')*λ/2 + @test f ≈ λ'*A*λ / 2 + end +end diff --git a/test/distributions.jl b/test/distributions.jl new file mode 100644 index 00000000..ab9617aa --- /dev/null +++ b/test/distributions.jl @@ -0,0 +1,45 @@ + +using ReTest +using Distributions +using Distributions: _logpdf +using LinearAlgebra +using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian + +@testset "distributions" begin + @testset "$(string(covtype)) Gaussian $(realtype)" for + covtype = [:diagonal, :fullrank], + realtype = [Float32, Float64] + + realtype = Float64 + ϵ = 1e-2 + n_dims = 10 + n_montecarlo = 1000_000 + + μ = randn(realtype, n_dims) + L₀ = randn(realtype, n_dims, n_dims) + Σ = if covtype == :fullrank + Σ = (L₀*L₀' + ϵ*I) |> Hermitian + else + Diagonal(exp.(randn(realtype, n_dims))) + end + + L = cholesky(Σ).L + q = if covtype == :fullrank + VIFullRankGaussian(μ, L |> LowerTriangular) + else + VIMeanFieldGaussian(μ, L |> Diagonal) + end + q_true = MvNormal(μ, Σ) + + z = randn(n_dims) + @test logpdf(q, z) ≈ logpdf(q_true, z) + @test _logpdf(q, z) ≈ _logpdf(q_true, z) + @test entropy(q) ≈ entropy(q_true) + + z_samples = rand(q, n_montecarlo) + threesigma = L + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a305c25e..44074197 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,28 +1,12 @@ -using Test -using Distributions, DistributionsAD -using AdvancedVI -include("optimisers.jl") +using ReTest: @testset, @test +#using Random +#using Statistics +#using Distributions, DistributionsAD -target = MvNormal(ones(2)) -logπ(z) = logpdf(target, z) -advi = ADVI(10, 1000) +println("Environment variables for testing") +println(ENV) -# Using a function z ↦ q(⋅∣z) -getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -q = vi(logπ, advi, getq, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 - -# OR: implement `update` and pass a `Distribution` -function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end])) -end - -q0 = TuringDiagMvNormal(zeros(2), ones(2)) -q = vi(logπ, advi, q0, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +include("ad.jl") +include("distributions.jl") From 06db2f02233e8e4e6010be6473ea7f356742a4a3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:15:03 +0100 Subject: [PATCH 036/206] add Accessors import statement --- src/AdvancedVI.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 76c6d859..5800cd93 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,7 @@ module AdvancedVI -using UnPack +using UnPack, Accessors import Random: AbstractRNG, default_rng import Distributions: logpdf, _logpdf, rand, _rand!, _rand! @@ -179,6 +179,7 @@ include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") # Variational Families + include("distributions/location_scale.jl") include("utils.jl") From 12de2bda787624b862772fc0b4fa55729ebb6ff9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 20:12:48 +0100 Subject: [PATCH 037/206] remove optimiser tests --- test/optimisers.jl | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 test/optimisers.jl diff --git a/test/optimisers.jl b/test/optimisers.jl deleted file mode 100644 index fae652ed..00000000 --- a/test/optimisers.jl +++ /dev/null @@ -1,17 +0,0 @@ -using Random, Test, LinearAlgebra, ForwardDiff -using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! - -θ = randn(10, 10) -@testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1)) - for t = 1:10^4 - x = rand(10) - Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit) - Δ = apply!(opt, θ_fit, Δ) - @. θ_fit = θ_fit - Δ - end - @test loss(rand(10, 100), θ_fit) < 0.01 - @test length(opt.acc) == 1 -end - From bbb2cc649fce6caddb751d0e5743d2fc2a814ad2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 20:12:59 +0100 Subject: [PATCH 038/206] refactor slightly generalize the distribution tests for the future --- test/distributions.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/distributions.jl b/test/distributions.jl index ab9617aa..07b3efdf 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -6,8 +6,9 @@ using LinearAlgebra using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian @testset "distributions" begin - @testset "$(string(covtype)) Gaussian $(realtype)" for - covtype = [:diagonal, :fullrank], + @testset "$(string(covtype)) $(basedist) $(realtype)" for + basedist = [:gaussian], + covtype = [:meanfield, :fullrank], realtype = [Float32, Float64] realtype = Float64 @@ -24,12 +25,14 @@ using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian end L = cholesky(Σ).L - q = if covtype == :fullrank + q = if covtype == :fullrank && basedist == :gaussian VIFullRankGaussian(μ, L |> LowerTriangular) - else + elseif covtype == :meanfield && basedist == :gaussian VIMeanFieldGaussian(μ, L |> Diagonal) end - q_true = MvNormal(μ, Σ) + q_true = if basedist == :gaussian + MvNormal(μ, Σ) + end z = randn(n_dims) @test logpdf(q, z) ≈ logpdf(q_true, z) From 197484655468ec5bab362380fb58d896a082b150 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:10:51 +0100 Subject: [PATCH 039/206] migrate to SimpleUnPack, migrate to ADTypes --- Project.toml | 3 +- src/AdvancedVI.jl | 150 ++++++++++---------------------------- src/ad.jl | 46 ------------ src/compat/enzyme.jl | 19 ++++- src/compat/reversediff.jl | 21 +++--- src/compat/zygote.jl | 16 +++- test/ad.jl | 14 ++-- 7 files changed, 90 insertions(+), 179 deletions(-) delete mode 100644 src/ad.jl diff --git a/Project.toml b/Project.toml index 93e3a52a..2fcc845e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.4" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -18,10 +19,10 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] Bijectors = "0.11, 0.12, 0.13" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5800cd93..573f7179 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,8 @@ module AdvancedVI -using UnPack, Accessors +using SimpleUnPack: @unpack +using Accessors import Random: AbstractRNG, default_rng import Distributions: logpdf, _logpdf, rand, _rand!, _rand! @@ -17,6 +18,8 @@ using LinearAlgebra: AbstractTriangular using LogDensityProblems +using ADTypes +using ADTypes: AbstractADType using ForwardDiff, Tracker using FillArrays @@ -30,78 +33,19 @@ using StatsBase: entropy const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) -include("ad.jl") - using Requires function __init__() @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin include("compat/zygote.jl") - export ZygoteAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ZygoteAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - y, back = Zygote.pullback(f, λ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("compat/reversediff.jl") - export ReverseDiffAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ReverseDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - tp = AdvancedVI.tape(f, λ) - ReverseDiff.gradient!(out, tp, λ) - return out - end end @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("compat/enzyme.jl") - export EnzymeAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:EnzymeAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(λ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) - return out - end end end -export - optimize, - ELBO, - ADVI, - ADVIEnergy, - ClosedFormEntropy, - MonteCarloEntropy, - LocationScale, - FullRankGaussian, - MeanFieldGaussian, - TruncatedADAGrad, - DecayedADAGrad - - """ grad!(f, λ, out) @@ -111,55 +55,7 @@ This implicitly also gives a default implementation of `optimize!`. """ function grad! end -""" - optimize(model, alg::VariationalInference) - optimize(model, alg::VariationalInference, q::VariationalPosterior) - optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) - -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. - -# Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior -""" -function optimize end - -function update end - -# default implementations -function grad!( - f::Function, - adtype::Type{<:ForwardDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - # Set chunk size and do ForwardMode. - chunk_size = getchunksize(adtype) - config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, λ) - else - ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) - end - ForwardDiff.gradient!(out, f, λ, config) -end - -function grad!( - f::Function, - ::Type{<:TrackerAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - λ_tracked = Tracker.param(λ) - y = f(λ_tracked) - Tracker.back!(y, 1.0) - - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(λ_tracked)) -end +include("grad.jl") # estimators abstract type AbstractVariationalObjective end @@ -170,6 +66,9 @@ abstract type AbstractEnergyEstimator end abstract type AbstractEntropyEstimator end abstract type AbstractControlVariate end +function init end +function update end + init(::Nothing) = nothing update(::Nothing, ::Nothing) = (nothing, nothing) @@ -178,11 +77,42 @@ include("objectives/elbo/advi.jl") include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") +export + ELBO, + ADVI, + ADVIEnergy, + ClosedFormEntropy, + MonteCarloEntropy + # Variational Families include("distributions/location_scale.jl") +export + VIFullRankGaussian, + VIMeanFieldGaussian + +""" + optimize(model, alg::VariationalInference) + optimize(model, alg::VariationalInference, q::VariationalPosterior) + optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) + +Constructs the variational posterior from the `model` and performs the optimization +following the configuration of the given `VariationalInference` instance. + +# Arguments +- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations +- `alg`: the VI algorithm used +- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. +- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` +- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior +""" +function optimize end + +include("optimize.jl") + +export optimize + include("utils.jl") -include("vi.jl") end # module diff --git a/src/ad.jl b/src/ad.jl deleted file mode 100644 index 62e785e1..00000000 --- a/src/ad.jl +++ /dev/null @@ -1,46 +0,0 @@ -############################## -# Global variables/constants # -############################## -const ADBACKEND = Ref(:forwarddiff) -setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -function setadbackend(::Val{:forward_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) - setadbackend(Val(:forwarddiff)) -end -function setadbackend(::Val{:forwarddiff}) - ADBACKEND[] = :forwarddiff -end - -function setadbackend(::Val{:reverse_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) - setadbackend(Val(:tracker)) -end -function setadbackend(::Val{:tracker}) - ADBACKEND[] = :tracker -end - -const ADSAFE = Ref(false) -function setadsafe(switch::Bool) - @info("[AdvancedVI]: global ADSAFE is set as $switch") - ADSAFE[] = switch -end - -const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically - -function setchunksize(chunk_size::Int) - @info("[AdvancedVI]: AD chunk size is set as $chunk_size") - CHUNKSIZE[] = chunk_size -end - -abstract type ADBackend end -struct ForwardDiffAD{chunk} <: ADBackend end -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk - -struct TrackerAD <: ADBackend end - -ADBackend() = ADBackend(ADBACKEND[]) -ADBackend(T::Symbol) = ADBackend(Val(T)) - -ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:tracker}) = TrackerAD -ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl index c6bb9ac3..cab50862 100644 --- a/src/compat/enzyme.jl +++ b/src/compat/enzyme.jl @@ -1,5 +1,16 @@ -struct EnzymeAD <: ADBackend end -ADBackend(::Val{:enzyme}) = EnzymeAD -function setadbackend(::Val{:enzyme}) - ADBACKEND[] = :enzyme + +function AdvancedVI.grad!( + f::Function, + ::AutoEnzyme, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + # Use `Enzyme.ReverseWithPrimal` once it is released: + # https://github.com/EnzymeAD/Enzyme.jl/pull/598 + y = f(λ) + DiffResults.value!(out, y) + dy = DiffResults.gradient(out) + fill!(dy, 0) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) + return out end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 721d0361..4d8f87d8 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,16 +1,19 @@ using .ReverseDiff: compile, GradientTape using .ReverseDiff.DiffResults: GradientResult -struct ReverseDiffAD{cache} <: ADBackend end -const RDCache = Ref(false) -setcache(b::Bool) = RDCache[] = b -getcache() = RDCache[] -ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()} -function setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff -end - tape(f, x) = GradientTape(f, x) function taperesult(f, x) return tape(f, x), GradientResult(x) end + +# Precompiled tapes are not properly supported yet. +function AdvancedVI.grad!( + f::Function, + ::AutoReverseDiff, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + tp = tape(f, λ) + ReverseDiff.gradient!(out, tp, λ) + return out +end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 40022e21..f1a29b87 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -1,5 +1,13 @@ -struct ZygoteAD <: ADBackend end -ADBackend(::Val{:zygote}) = ZygoteAD -function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote + +function AdvancedVI.grad!( + f::Function, + ::AutoZygote, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + y, back = Zygote.pullback(f, λ) + dy = first(back(1.0)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, dy) + return out end diff --git a/test/ad.jl b/test/ad.jl index c084165c..6b587598 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,17 +1,21 @@ using ReTest using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote -using AdvancedVI: grad! +using ADTypes @testset "ad" begin - @testset "$(string(adsymbol))" for adsymbol ∈ [ - :forwarddiff, :reversediff, :tracker, :enzyme, :zygote] + @testset "$(adname)" for (adname, adsymbol) ∈ Dict( + :ForwardDiffAuto => AutoForwardDiff(), + :ForwardDiff => AutoForwardDiff(10), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Tracker => AutoTracker(), + ) D = 10 A = randn(D, D) λ = randn(D) - AdvancedVI.setadbackend(adsymbol) grad_buf = DiffResults.GradientResult(λ) - AdvancedVI.grad!(AdvancedVI.ADBackend(), λ, grad_buf) do λ′ + AdvancedVI.grad!(adsymbol, λ, grad_buf) do λ′ λ′'*A*λ′ / 2 end ∇ = DiffResults.gradient(grad_buf) From 19c62c888fafbed9271e66cf1c7ced7b11a90457 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:11:12 +0100 Subject: [PATCH 040/206] rename vi.jl to optimize.jl --- src/{vi.jl => optimize.jl} | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) rename src/{vi.jl => optimize.jl} (89%) diff --git a/src/vi.jl b/src/optimize.jl similarity index 89% rename from src/vi.jl rename to src/optimize.jl index 842f187e..07184900 100644 --- a/src/vi.jl +++ b/src/optimize.jl @@ -19,6 +19,7 @@ function optimize( progress ::Bool = true, callback! = nothing, terminate = (args...) -> false, + adback::AbstractADType = AutoForwardDiff(), ) opt_state = Optimisers.init(optimizer, λ) est_state = init(objective) @@ -33,7 +34,8 @@ function optimize( for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, est_state, stat′ = estimate_gradient(rng, objective, est_state, λ, restructure, grad_buf) + grad_buf, est_state, stat′ = estimate_gradient( + rng, adback, objective, est_state, λ, restructure, grad_buf) g = DiffResults.gradient(grad_buf) stat = merge(stat, stat′) @@ -51,6 +53,9 @@ function optimize( AdvancedVI.DEBUG && @debug "Step $t" stat... + q = project_domain(q) + λ, _ = Optimisers.destructure(q) + pm_next!(prog, stat) stats[t] = stat From 63da51de8870575971b8e70e28dfc6c2265c5e30 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:11:25 +0100 Subject: [PATCH 041/206] fix estimate_gradient to use adtypes --- src/objectives/elbo/advi.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 311a94f3..ed834273 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -41,6 +41,7 @@ end function estimate_gradient( rng::AbstractRNG, + adback::AbstractADType, advi::ADVI, est_state, λ::Vector{<:Real}, @@ -50,7 +51,7 @@ function estimate_gradient( # Gradient-stopping for computing the sticking-the-landing control variate q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing - grad!(ADBackend(), λ, out) do λ′ + grad!(adback, λ, out) do λ′ q_η = restructure(λ′) q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η -advi(q_η; rng, q_η_entropy) From 65ab47395fa4fe88b6b65323325c68b5c0ee078a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:17:20 +0100 Subject: [PATCH 042/206] add exact inference tests --- test/distributions.jl | 5 +-- test/exact.jl | 64 +++++++++++++++++++++++++++++++++++ test/exact/normallognormal.jl | 52 ++++++++++++++++++++++++++++ test/runtests.jl | 13 +++---- 4 files changed, 124 insertions(+), 10 deletions(-) create mode 100644 test/exact.jl create mode 100644 test/exact/normallognormal.jl diff --git a/test/distributions.jl b/test/distributions.jl index 07b3efdf..074cad7c 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,9 +1,6 @@ using ReTest -using Distributions using Distributions: _logpdf -using LinearAlgebra -using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian @testset "distributions" begin @testset "$(string(covtype)) $(basedist) $(realtype)" for @@ -17,7 +14,7 @@ using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian n_montecarlo = 1000_000 μ = randn(realtype, n_dims) - L₀ = randn(realtype, n_dims, n_dims) + L₀ = randn(realtype, n_dims, n_dims) |> LowerTriangular Σ = if covtype == :fullrank Σ = (L₀*L₀' + ϵ*I) |> Hermitian else diff --git a/test/exact.jl b/test/exact.jl new file mode 100644 index 00000000..27b92c04 --- /dev/null +++ b/test/exact.jl @@ -0,0 +1,64 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using ReTest +using Turing, LogDensityProblems +using Optimisers +using Distributions +using LinearAlgebra +using SimpleUnPack: @unpack + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("inference/normallognormal.jl") + +@testset "exact" begin + @testset "$(modelname) $(realtype)" for + realtype ∈ [Float32, Float64], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + :NormalLogNormalFullRank => normallognormal_fullrank, + ) + + T = 10000 + modelstats = modelconstr(realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + prob = DynamicPPL.LogDensityFunction(model) + + μ₀ = zeros(realtype, n_dims) + L₀ = if is_meanfield + ones(realtype, n_dims) |> Diagonal + else + diagm(ones(realtype, n_dims)) |> LowerTriangular + end + q₀ = if is_meanfield + AdvancedVI.VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) + else + AdvancedVI.VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) + end + + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + + objective = AdvancedVI.ADVI(prob, b⁻¹, 10) + q, stats = AdvancedVI.optimize( + objective, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + ) + + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/√T + end +end diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl new file mode 100644 index 00000000..4e9e1404 --- /dev/null +++ b/test/exact/normallognormal.jl @@ -0,0 +1,52 @@ + +function normallognormal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = π + μ_y = randn(rng, realtype, n_dims) + L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular + ϵ = realtype(1.0) + Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian + + Turing.@model function normallognormal() + x ~ LogNormal(μ_x, σ_x) + y ~ MvNormal(μ_y, Σ_y) + end + model = normallognormal() + + Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) + Σ[1,1] = σ_x^2 + Σ[2:end,2:end] = Σ_y + Σ = Σ |> Hermitian + + μ = vcat(μ_x, μ_y) + L = cholesky(Σ).L |> LowerTriangular + + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(realtype) + n_dims = 5 + + μ_x = randn(realtype) + σ_x = π + μ_y = randn(realtype, n_dims) + ϵ = realtype(1.0) + Σ_y = Diagonal(exp.(randn(realtype, n_dims))) + + Turing.@model function normallognormal() + x ~ LogNormal(μ_x, σ_x) + y ~ MvNormal(μ_y, Σ_y) + end + model = normallognormal() + + σ² = Vector{realtype}(undef, n_dims+1) + σ²[1] = σ_x^2 + σ²[2:end] = diag(Σ_y) + + μ = vcat(μ_x, μ_y) + L = sqrt.(σ²) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end diff --git a/test/runtests.jl b/test/runtests.jl index 44074197..26f9a06f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,13 @@ using ReTest: @testset, @test -#using Random -#using Statistics -#using Distributions, DistributionsAD - -println("Environment variables for testing") -println(ENV) +using Random +using Random: default_rng +using Statistics +using Distributions, DistributionsAD +using LinearAlgebra +using AdvancedVI include("ad.jl") include("distributions.jl") +include("exact.jl") From 3e5a4520835f0d182b8f7c4aaef0529ff37498e6 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 00:28:18 +0100 Subject: [PATCH 043/206] remove Turing dependency in tests --- test/exact.jl | 9 ++++--- test/exact/normallognormal.jl | 47 +++++++++++++++++++++++------------ test/runtests.jl | 9 ++++++- 3 files changed, 44 insertions(+), 21 deletions(-) diff --git a/test/exact.jl b/test/exact.jl index 27b92c04..d5283e8e 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -2,9 +2,11 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using ReTest -using Turing, LogDensityProblems +using Bijectors +using LogDensityProblems using Optimisers using Distributions +using PDMats using LinearAlgebra using SimpleUnPack: @unpack @@ -16,7 +18,7 @@ struct TestModel{M,L,S} is_meanfield::Bool end -include("inference/normallognormal.jl") +include("exact/normallognormal.jl") @testset "exact" begin @testset "$(modelname) $(realtype)" for @@ -32,7 +34,6 @@ include("inference/normallognormal.jl") b = Bijectors.bijector(model) b⁻¹ = inverse(b) - prob = DynamicPPL.LogDensityFunction(model) μ₀ = zeros(realtype, n_dims) L₀ = if is_meanfield @@ -48,7 +49,7 @@ include("inference/normallognormal.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - objective = AdvancedVI.ADVI(prob, b⁻¹, 10) + objective = AdvancedVI.ADVI(model, b⁻¹, 10) q, stats = AdvancedVI.optimize( objective, q₀, T; optimizer = Optimisers.AdaGrad(1e-1), diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl index 4e9e1404..e39ec2cb 100644 --- a/test/exact/normallognormal.jl +++ b/test/exact/normallognormal.jl @@ -1,4 +1,31 @@ +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + function normallognormal_fullrank(realtype; rng = default_rng()) n_dims = 5 @@ -9,11 +36,7 @@ function normallognormal_fullrank(realtype; rng = default_rng()) ϵ = realtype(1.0) Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian - Turing.@model function normallognormal() - x ~ LogNormal(μ_x, σ_x) - y ~ MvNormal(μ_y, Σ_y) - end - model = normallognormal() + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) Σ[1,1] = σ_x^2 @@ -33,20 +56,12 @@ function normallognormal_meanfield(realtype) σ_x = π μ_y = randn(realtype, n_dims) ϵ = realtype(1.0) - Σ_y = Diagonal(exp.(randn(realtype, n_dims))) - - Turing.@model function normallognormal() - x ~ LogNormal(μ_x, σ_x) - y ~ MvNormal(μ_y, Σ_y) - end - model = normallognormal() + σ_y = exp.(randn(realtype, n_dims)) - σ² = Vector{realtype}(undef, n_dims+1) - σ²[1] = σ_x^2 - σ²[2:end] = diag(Σ_y) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) μ = vcat(μ_x, μ_y) - L = sqrt.(σ²) |> Diagonal + L = vcat(σ_x, σ_y) |> Diagonal TestModel(model, μ, L, n_dims+1, true) end diff --git a/test/runtests.jl b/test/runtests.jl index 26f9a06f..0b86222b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,20 @@ +using Comonicon using ReTest: @testset, @test using Random using Random: default_rng using Statistics -using Distributions, DistributionsAD +using Distributions using LinearAlgebra using AdvancedVI +const GROUP = get(ENV, "AHMC_TEST_GROUP", "AdvancedHMC") + include("ad.jl") include("distributions.jl") include("exact.jl") +@main function runtests(patterns...; dry::Bool = false) + retest(patterns...; dry = dry, verbose = Inf) +end + From 3117cec8952b80b58e205726f2abe9f77ffddf80 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 02:44:22 +0100 Subject: [PATCH 044/206] remove unused projection --- src/optimize.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 07184900..2acfbc0b 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -43,7 +43,6 @@ function optimize( Optimisers.subtract!(λ, Δλ) stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) stat = merge(stat, stat′) - q = restructure(λ) if !isnothing(callback!) @@ -53,9 +52,6 @@ function optimize( AdvancedVI.DEBUG && @debug "Step $t" stat... - q = project_domain(q) - λ, _ = Optimisers.destructure(q) - pm_next!(prog, stat) stats[t] = stat From b1ca9cf5cfad2345c92481c7519b12e1520776ef Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:03:57 +0100 Subject: [PATCH 045/206] remove redundant `ADVIEnergy` object (now baked into `ADVI`) --- src/AdvancedVI.jl | 2 +- src/objectives/elbo/advi.jl | 38 ++++++++++++++++++++---------- src/objectives/elbo/advi_energy.jl | 37 ----------------------------- 3 files changed, 26 insertions(+), 51 deletions(-) delete mode 100644 src/objectives/elbo/advi_energy.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 573f7179..502112c7 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -74,7 +74,6 @@ init(::Nothing) = nothing update(::Nothing, ::Nothing) = (nothing, nothing) include("objectives/elbo/advi.jl") -include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") export @@ -82,6 +81,7 @@ export ADVI, ADVIEnergy, ClosedFormEntropy, + StickingTheLandingEntropy, MonteCarloEntropy # Variational Families diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index ed834273..9cd2433e 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,32 +1,41 @@ -struct ADVI{EnergyEst <: AbstractEnergyEstimator, +struct ADVI{Tlogπ, B, EntropyEst <: AbstractEntropyEstimator, ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective - energy_estimator::EnergyEst + ℓπ::Tlogπ + b⁻¹::B entropy_estimator::EntropyEst control_variate::ControlVar n_samples::Int + + function ADVI(prob, b⁻¹, entropy_estimator, control_variate, n_samples) + cap = LogDensityProblems.capabilities(prob) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + new{typeof(ℓπ), typeof(b⁻¹), typeof(entropy_estimator), typeof(control_variate)}( + ℓπ, b⁻¹, entropy_estimator, control_variate, n_samples + ) + end end skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) init(advi::ADVI) = init(advi.control_variate) -Base.show(io::IO, advi::ADVI) = print( - io, - "ADVI(energy_estimator=$(advi.energy_estimator), " * - "entropy_estimator=$(advi.entropy_estimator), " * - "control_variate=$(advi.control_variate), " * - "n_samples=$(advi.n_samples))") - -function ADVI(energy_estimator::AbstractEnergyEstimator, +function ADVI(ℓπ, b⁻¹, entropy_estimator::AbstractEntropyEstimator, n_samples::Int) - ADVI(energy_estimator, entropy_estimator, nothing, n_samples) + ADVI(ℓπ, b⁻¹, entropy_estimator, nothing, n_samples) end function ADVI(ℓπ, b⁻¹, n_samples::Int) - ADVI(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) + ADVI(ℓπ, b⁻¹, ClosedFormEntropy(), nothing, n_samples) end function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; @@ -34,7 +43,10 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; n_samples ::Int = advi.n_samples, ηs ::AbstractMatrix = rand(rng, q_η, n_samples), q_η_entropy::ContinuousMultivariateDistribution = q_η) - 𝔼ℓ = advi.energy_estimator(q_η, ηs) + 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b⁻¹, ηᵢ) + (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples + end ℍ = advi.entropy_estimator(q_η_entropy, ηs) 𝔼ℓ + ℍ end diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl deleted file mode 100644 index 078a157e..00000000 --- a/src/objectives/elbo/advi_energy.jl +++ /dev/null @@ -1,37 +0,0 @@ - -struct ADVIEnergy{Tlogπ, B} <: AbstractEnergyEstimator - # Automatic differentiation variational inference - # - # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). - # Automatic differentiation variational inference. - # Journal of machine learning research. - - ℓπ::Tlogπ - b⁻¹::B - - function ADVIEnergy(prob, b⁻¹) - # Could check whether the support of b⁻¹ and ℓπ match - cap = LogDensityProblems.capabilities(prob) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) - end - ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹) - end -end - -ADVIEnergy(prob) = ADVIEnergy(prob, identity) - -Base.show(io::IO, energy::ADVIEnergy) = print(io, "ADVIEnergy()") - -function (energy::ADVIEnergy)(q, ηs::AbstractMatrix) - n_samples = size(ηs, 2) - mapreduce(+, eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(energy.b⁻¹, ηᵢ) - (energy.ℓπ(zᵢ) + logdetjacᵢ) / n_samples - end -end From fcbb729378e3e4e16e6288a9336511f2b616b557 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:04:21 +0100 Subject: [PATCH 046/206] add more tests, fix rng seed for tests --- test/exact.jl | 69 +++++++++++++++++++++++++++-------- test/exact/normallognormal.jl | 15 ++++---- test/runtests.jl | 2 +- 3 files changed, 61 insertions(+), 25 deletions(-) diff --git a/test/exact.jl b/test/exact.jl index d5283e8e..637a95ed 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -21,15 +21,22 @@ end include("exact/normallognormal.jl") @testset "exact" begin - @testset "$(modelname) $(realtype)" for + @testset "$(modelname) $(objname) $(realtype)" for realtype ∈ [Float32, Float64], (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), + :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), ) - + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + T = 10000 - modelstats = modelconstr(realtype) + modelstats = modelconstr(realtype; rng) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats b = Bijectors.bijector(model) @@ -42,24 +49,54 @@ include("exact/normallognormal.jl") diagm(ones(realtype, n_dims)) |> LowerTriangular end q₀ = if is_meanfield - AdvancedVI.VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) + VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) else - AdvancedVI.VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) + VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) end - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + obj = objective(model, b⁻¹, 10) - objective = AdvancedVI.ADVI(model, b⁻¹, 10) - q, stats = AdvancedVI.optimize( - objective, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - ) + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-0), + progress = PROGRESS, + rng = rng, + ) - μ = q.location - L = q.scale - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/√T + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end - @test Δλ ≤ Δλ₀/√T + @testset "determinism" begin + rng = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-2), + progress = PROGRESS, + rng = rng, + ) + μ = q.location + L = q.scale + + rng_repl = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-2), + progress = PROGRESS, + rng = rng_repl, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end end end + diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl index e39ec2cb..7c5c000d 100644 --- a/test/exact/normallognormal.jl +++ b/test/exact/normallognormal.jl @@ -30,10 +30,10 @@ function normallognormal_fullrank(realtype; rng = default_rng()) n_dims = 5 μ_x = randn(rng, realtype) - σ_x = π + σ_x = ℯ μ_y = randn(rng, realtype, n_dims) L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(1.0) + ϵ = realtype(n_dims) Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) @@ -49,14 +49,13 @@ function normallognormal_fullrank(realtype; rng = default_rng()) TestModel(model, μ, L, n_dims+1, false) end -function normallognormal_meanfield(realtype) +function normallognormal_meanfield(realtype; rng = default_rng()) n_dims = 5 - μ_x = randn(realtype) - σ_x = π - μ_y = randn(realtype, n_dims) - ϵ = realtype(1.0) - σ_y = exp.(randn(realtype, n_dims)) + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) diff --git a/test/runtests.jl b/test/runtests.jl index 0b86222b..b571f8b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using Comonicon using ReTest: @testset, @test using Random -using Random: default_rng +using Random123 using Statistics using Distributions using LinearAlgebra From 0f6f6a429ba74e491943ad96fa52ff9f897cc862 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:04:35 +0100 Subject: [PATCH 047/206] add more tests, fix seed for tests --- test/distributions.jl | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/test/distributions.jl b/test/distributions.jl index 074cad7c..073fff64 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -8,17 +8,19 @@ using Distributions: _logpdf covtype = [:meanfield, :fullrank], realtype = [Float32, Float64] + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) realtype = Float64 ϵ = 1e-2 n_dims = 10 n_montecarlo = 1000_000 - μ = randn(realtype, n_dims) - L₀ = randn(realtype, n_dims, n_dims) |> LowerTriangular + μ = randn(rng, realtype, n_dims) + L₀ = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular Σ = if covtype == :fullrank Σ = (L₀*L₀' + ϵ*I) |> Hermitian else - Diagonal(exp.(randn(realtype, n_dims))) + Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) end L = cholesky(Σ).L @@ -31,15 +33,26 @@ using Distributions: _logpdf MvNormal(μ, Σ) end - z = randn(n_dims) - @test logpdf(q, z) ≈ logpdf(q_true, z) - @test _logpdf(q, z) ≈ _logpdf(q_true, z) - @test entropy(q) ≈ entropy(q_true) + @testset "logpdf" begin + z = randn(rng, realtype, n_dims) + @test logpdf(q, z) ≈ logpdf(q_true, z) + @test _logpdf(q, z) ≈ _logpdf(q_true, z) + @test eltype(logpdf(q, z)) == realtype + @test eltype(_logpdf(q, z)) == realtype + end + + @testset "entropy" begin + @test eltype(entropy(q)) == realtype + @test entropy(q) ≈ entropy(q_true) + end - z_samples = rand(q, n_montecarlo) - threesigma = L - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @testset "sampling" begin + z_samples = rand(rng, q, n_montecarlo) + threesigma = L + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end end end From f5f5863b55af07ea1009528e5b8e1fdb1bfc96df Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:16:49 +0100 Subject: [PATCH 048/206] fix non-determinism bug --- src/distributions/location_scale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index c1803ffe..e9e8c743 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -47,7 +47,7 @@ end function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) @unpack location, scale, dist = q n_dims = length(location) - scale*rand(dist, n_dims, num_samples) .+ location + scale*rand(rng, dist, n_dims, num_samples) .+ location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) From ade0d1007c1507fb0359d744fa640349314e325d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 04:56:29 +0100 Subject: [PATCH 049/206] fix test hyperparameters so that tests pass, minor cleanups --- src/distributions/location_scale.jl | 12 ++++++++++++ src/objectives/elbo/advi.jl | 6 ++++++ src/optimize.jl | 6 ++++-- test/exact.jl | 10 +++++----- test/exact/normallognormal.jl | 2 +- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index e9e8c743..dc9c1b27 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,4 +1,16 @@ +""" + +The [location scale] variational family broadly represents various variational +families using `location` and `scale` variational parameters. + +Multivariate Student-t variational family with ``\\nu``-degrees of freedom can +be constructed as: +```julia +q₀ = VILocationScale(μ, L, StudentT(ν), eps(Float32)) +``` + +""" struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution location::L scale ::S diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 9cd2433e..b9b1185f 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -24,6 +24,12 @@ struct ADVI{Tlogπ, B, end end +Base.show(io::IO, advi::ADVI) = + print(io, + "ADVI(entropy_estimator=$(advi.entropy_estimator), " * + "control_variate=$(advi.control_variate), " * + "n_samples=$(advi.n_samples))") + skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) init(advi::ADVI) = init(advi.control_variate) diff --git a/src/optimize.jl b/src/optimize.jl index 2acfbc0b..dcd1c439 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -41,9 +41,11 @@ function optimize( opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g) Optimisers.subtract!(λ, Δλ) + stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) stat = merge(stat, stat′) - q = restructure(λ) + + q = restructure(λ) if !isnothing(callback!) stat′ = callback!(q, stat) @@ -56,7 +58,7 @@ function optimize( stats[t] = stat # Termination decision is work in progress - if terminate(rng, q, objective, stat) + if terminate(rng, λ, q, objective, stat) stats = stats[1:t] break end diff --git a/test/exact.jl b/test/exact.jl index 637a95ed..d1be4626 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -49,9 +49,9 @@ include("exact/normallognormal.jl") diagm(ones(realtype, n_dims)) |> LowerTriangular end q₀ = if is_meanfield - VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) + VIMeanFieldGaussian(μ₀, L₀) else - VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) + VIFullRankGaussian(μ₀, L₀) end obj = objective(model, b⁻¹, 10) @@ -60,7 +60,7 @@ include("exact/normallognormal.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-0), + optimizer = Optimisers.AdaGrad(1e-1), progress = PROGRESS, rng = rng, ) @@ -78,7 +78,7 @@ include("exact/normallognormal.jl") rng = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-2), + optimizer = Optimisers.AdaGrad(1e-1), progress = PROGRESS, rng = rng, ) @@ -88,7 +88,7 @@ include("exact/normallognormal.jl") rng_repl = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-2), + optimizer = Optimisers.AdaGrad(1e-1), progress = PROGRESS, rng = rng_repl, ) diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl index 7c5c000d..18e8b4a3 100644 --- a/test/exact/normallognormal.jl +++ b/test/exact/normallognormal.jl @@ -33,7 +33,7 @@ function normallognormal_fullrank(realtype; rng = default_rng()) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(n_dims) + ϵ = realtype(n_dims*2) Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) From 0caf7a9ef768ce97c7498c981d5ef60ee673488f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 19:37:45 +0100 Subject: [PATCH 050/206] fix minor reorganization --- src/AdvancedVI.jl | 9 +-- test/exact.jl | 102 ---------------------------------- test/exact/normallognormal.jl | 66 ---------------------- test/runtests.jl | 4 +- 4 files changed, 4 insertions(+), 177 deletions(-) delete mode 100644 test/exact.jl delete mode 100644 test/exact/normallognormal.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 502112c7..86c9fc44 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,7 @@ module AdvancedVI -using SimpleUnPack: @unpack +using SimpleUnPack: @unpack, @pack! using Accessors import Random: AbstractRNG, default_rng @@ -60,17 +60,14 @@ include("grad.jl") # estimators abstract type AbstractVariationalObjective end +function init end function estimate_gradient end -abstract type AbstractEnergyEstimator end +# ADVI-specific interfaces abstract type AbstractEntropyEstimator end abstract type AbstractControlVariate end -function init end function update end - -init(::Nothing) = nothing - update(::Nothing, ::Nothing) = (nothing, nothing) include("objectives/elbo/advi.jl") diff --git a/test/exact.jl b/test/exact.jl deleted file mode 100644 index d1be4626..00000000 --- a/test/exact.jl +++ /dev/null @@ -1,102 +0,0 @@ - -const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false - -using ReTest -using Bijectors -using LogDensityProblems -using Optimisers -using Distributions -using PDMats -using LinearAlgebra -using SimpleUnPack: @unpack - -struct TestModel{M,L,S} - model::M - μ_true::L - L_true::S - n_dims::Int - is_meanfield::Bool -end - -include("exact/normallognormal.jl") - -@testset "exact" begin - @testset "$(modelname) $(objname) $(realtype)" for - realtype ∈ [Float32, Float64], - (modelname, modelconstr) ∈ Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - :NormalLogNormalFullRank => normallognormal_fullrank, - ), - (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), - :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), - :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), - ) - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - T = 10000 - modelstats = modelconstr(realtype; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) - - μ₀ = zeros(realtype, n_dims) - L₀ = if is_meanfield - ones(realtype, n_dims) |> Diagonal - else - diagm(ones(realtype, n_dims)) |> LowerTriangular - end - q₀ = if is_meanfield - VIMeanFieldGaussian(μ₀, L₀) - else - VIFullRankGaussian(μ₀, L₀) - end - - obj = objective(model, b⁻¹, 10) - - @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) - - μ = q.location - L = q.scale - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ Δλ₀/√T - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) - μ = q.location - L = q.scale - - rng_repl = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng_repl, - ) - μ_repl = q.location - L_repl = q.scale - @test μ == μ_repl - @test L == L_repl - end - end -end - diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl deleted file mode 100644 index 18e8b4a3..00000000 --- a/test/exact/normallognormal.jl +++ /dev/null @@ -1,66 +0,0 @@ - -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end - -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end - -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end - -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end - -function normallognormal_fullrank(realtype; rng = default_rng()) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(n_dims*2) - Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian - - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) - - Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) - Σ[1,1] = σ_x^2 - Σ[2:end,2:end] = Σ_y - Σ = Σ |> Hermitian - - μ = vcat(μ_x, μ_y) - L = cholesky(Σ).L |> LowerTriangular - - TestModel(model, μ, L, n_dims+1, false) -end - -function normallognormal_meanfield(realtype; rng = default_rng()) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) - - μ = vcat(μ_x, μ_y) - L = vcat(σ_x, σ_y) |> Diagonal - - TestModel(model, μ, L, n_dims+1, true) -end diff --git a/test/runtests.jl b/test/runtests.jl index b571f8b8..ddc1d09c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,11 +8,9 @@ using Distributions using LinearAlgebra using AdvancedVI -const GROUP = get(ENV, "AHMC_TEST_GROUP", "AdvancedHMC") - include("ad.jl") include("distributions.jl") -include("exact.jl") +include("advi_locscale.jl") @main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) From 5658cbf10e3f6e64d7b03380d4c026951cb3f0c2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 19:40:59 +0100 Subject: [PATCH 051/206] add missing files --- test/Project.toml | 20 +++++++ test/advi_locscale.jl | 102 +++++++++++++++++++++++++++++++++ test/models/normallognormal.jl | 66 +++++++++++++++++++++ 3 files changed, 188 insertions(+) create mode 100644 test/Project.toml create mode 100644 test/advi_locscale.jl create mode 100644 test/models/normallognormal.jl diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..2f38c88f --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,20 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" +ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl new file mode 100644 index 00000000..2beb0547 --- /dev/null +++ b/test/advi_locscale.jl @@ -0,0 +1,102 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using ReTest +using Bijectors +using LogDensityProblems +using Optimisers +using Distributions +using PDMats +using LinearAlgebra +using SimpleUnPack: @unpack + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("models/normallognormal.jl") + +@testset "exact" begin + @testset "$(modelname) $(objname) $(realtype)" for + realtype ∈ [Float32, Float64], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + :NormalLogNormalFullRank => normallognormal_fullrank, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), + :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), + ) + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + T = 10000 + modelstats = modelconstr(realtype; rng) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + + μ₀ = zeros(realtype, n_dims) + L₀ = if is_meanfield + ones(realtype, n_dims) |> Diagonal + else + diagm(ones(realtype, n_dims)) |> LowerTriangular + end + q₀ = if is_meanfield + VIMeanFieldGaussian(μ₀, L₀) + else + VIFullRankGaussian(μ₀, L₀) + end + + obj = objective(model, b⁻¹, 10) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + ) + + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/√T + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + ) + μ = q.location + L = q.scale + + rng_repl = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng_repl, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + end +end + diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl new file mode 100644 index 00000000..18e8b4a3 --- /dev/null +++ b/test/models/normallognormal.jl @@ -0,0 +1,66 @@ + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +function normallognormal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular + ϵ = realtype(n_dims*2) + Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian + + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) + + Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) + Σ[1,1] = σ_x^2 + Σ[2:end,2:end] = Σ_y + Σ = Σ |> Hermitian + + μ = vcat(μ_x, μ_y) + L = cholesky(Σ).L |> LowerTriangular + + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(realtype; rng = default_rng()) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + + μ = vcat(μ_x, μ_y) + L = vcat(σ_x, σ_y) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end From c712a9762afdbc60468953bfeab1ad076a6cc2f9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 19:51:14 +0100 Subject: [PATCH 052/206] fix add missing file, rename adbackend argument --- src/grad.jl | 30 ++++++++++++++++++++++++++++++ src/optimize.jl | 2 +- 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 src/grad.jl diff --git a/src/grad.jl b/src/grad.jl new file mode 100644 index 00000000..e68e1623 --- /dev/null +++ b/src/grad.jl @@ -0,0 +1,30 @@ + +# default implementations +function grad!( + f::Function, + adtype::AutoForwardDiff{chunksize}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) where {chunksize} + # Set chunk size and do ForwardMode. + config = if isnothing(chunksize) + ForwardDiff.GradientConfig(f, λ) + else + ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunksize)) + end + ForwardDiff.gradient!(out, f, λ, config) +end + +function grad!( + f::Function, + ::AutoTracker, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) + λ_tracked = Tracker.param(λ) + y = f(λ_tracked) + Tracker.back!(y, 1.0) + + DiffResults.value!(out, Tracker.data(y)) + DiffResults.gradient!(out, Tracker.grad(λ_tracked)) +end diff --git a/src/optimize.jl b/src/optimize.jl index dcd1c439..16995925 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -19,7 +19,7 @@ function optimize( progress ::Bool = true, callback! = nothing, terminate = (args...) -> false, - adback::AbstractADType = AutoForwardDiff(), + adbackend::AbstractADType = AutoForwardDiff(), ) opt_state = Optimisers.init(optimizer, λ) est_state = init(objective) From bee839d91399ce9cc2d776f907dd9197e14aa241 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 20:03:16 +0100 Subject: [PATCH 053/206] fix errors --- src/AdvancedVI.jl | 2 ++ src/objectives/elbo/advi.jl | 4 ++-- src/optimize.jl | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 86c9fc44..4010b1fe 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -63,6 +63,8 @@ abstract type AbstractVariationalObjective end function init end function estimate_gradient end +init(::Nothing) = nothing + # ADVI-specific interfaces abstract type AbstractEntropyEstimator end abstract type AbstractControlVariate end diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index b9b1185f..1fb6b0c6 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -59,7 +59,7 @@ end function estimate_gradient( rng::AbstractRNG, - adback::AbstractADType, + adbackend::AbstractADType, advi::ADVI, est_state, λ::Vector{<:Real}, @@ -69,7 +69,7 @@ function estimate_gradient( # Gradient-stopping for computing the sticking-the-landing control variate q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing - grad!(adback, λ, out) do λ′ + grad!(adbackend, λ, out) do λ′ q_η = restructure(λ′) q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η -advi(q_η; rng, q_η_entropy) diff --git a/src/optimize.jl b/src/optimize.jl index 16995925..8b36df04 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -35,7 +35,7 @@ function optimize( stat = (iteration=t,) grad_buf, est_state, stat′ = estimate_gradient( - rng, adback, objective, est_state, λ, restructure, grad_buf) + rng, adbackend, objective, est_state, λ, restructure, grad_buf) g = DiffResults.gradient(grad_buf) stat = merge(stat, stat′) From 913911ec74f835d566e2f19b0df16358a3fd055b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 20:03:23 +0100 Subject: [PATCH 054/206] rename test suite --- test/advi_locscale.jl | 149 +++++++++++++++++++++++------------------- 1 file changed, 80 insertions(+), 69 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 2beb0547..342b9db1 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -20,83 +20,94 @@ end include("models/normallognormal.jl") -@testset "exact" begin - @testset "$(modelname) $(objname) $(realtype)" for - realtype ∈ [Float32, Float64], - (modelname, modelconstr) ∈ Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - :NormalLogNormalFullRank => normallognormal_fullrank, - ), - (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), - :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), - :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), - ) - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - T = 10000 - modelstats = modelconstr(realtype; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats +@testset "advi" begin + @testset "locscale" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float32, Float64], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + :NormalLogNormalFullRank => normallognormal_fullrank, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), + :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) - μ₀ = zeros(realtype, n_dims) - L₀ = if is_meanfield - ones(realtype, n_dims) |> Diagonal - else - diagm(ones(realtype, n_dims)) |> LowerTriangular - end - q₀ = if is_meanfield - VIMeanFieldGaussian(μ₀, L₀) - else - VIFullRankGaussian(μ₀, L₀) - end + T = 10000 + modelstats = modelconstr(realtype; rng) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - obj = objective(model, b⁻¹, 10) + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) - @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) + μ₀ = zeros(realtype, n_dims) + L₀ = if is_meanfield + ones(realtype, n_dims) |> Diagonal + else + diagm(ones(realtype, n_dims)) |> LowerTriangular + end + q₀ = if is_meanfield + VIMeanFieldGaussian(μ₀, L₀) + else + VIFullRankGaussian(μ₀, L₀) + end - μ = q.location - L = q.scale - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + obj = objective(model, b⁻¹, 10) - @test Δλ ≤ Δλ₀/√T - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + adbackend = adbackend, + ) - @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) - μ = q.location - L = q.scale + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - rng_repl = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng_repl, - ) - μ_repl = q.location - L_repl = q.scale - @test μ == μ_repl - @test L == L_repl + @test Δλ ≤ Δλ₀/√T + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + adbackend = adbackend, + ) + μ = q.location + L = q.scale + + rng_repl = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng_repl, + adbackend = adbackend, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end end end end - From d50cabb0f0b7b7fac8bfd79c43ef38196b2df8c9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 01:59:43 +0100 Subject: [PATCH 055/206] refactor renamed arguments for ADVI to be shorter --- Project.toml | 3 +- src/AdvancedVI.jl | 7 ++-- src/objectives/elbo/advi.jl | 59 +++++++++++++++++----------------- src/objectives/elbo/entropy.jl | 42 ++++++++++++++---------- test/ad.jl | 10 +++--- test/advi_locscale.jl | 18 +++++------ 6 files changed, 73 insertions(+), 66 deletions(-) diff --git a/Project.toml b/Project.toml index 2fcc845e..cf698f7a 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -25,9 +24,9 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] +ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" -DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" DocStringExtensions = "0.8, 0.9" ForwardDiff = "0.10.3" ProgressMeter = "1.0.0" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 4010b1fe..e3dd85a8 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -5,7 +5,10 @@ using SimpleUnPack: @unpack, @pack! using Accessors import Random: AbstractRNG, default_rng -import Distributions: logpdf, _logpdf, rand, _rand!, _rand! +using Distributions +import Distributions: + logpdf, _logpdf, rand, _rand!, _rand!, + ContinuousMultivariateDistribution using Functors using Optimisers @@ -24,8 +27,6 @@ using ForwardDiff, Tracker using FillArrays using PDMats -using Distributions, DistributionsAD -using Distributions: ContinuousMultivariateDistribution using Bijectors using StatsBase diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 1fb6b0c6..e965ea73 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,14 +1,28 @@ +""" + ADVI + +Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. + +# Requirements +- ``q_{\\lambda}`` implements `rand`. +- ``\\pi`` must be differentiable + +Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. +""" struct ADVI{Tlogπ, B, - EntropyEst <: AbstractEntropyEstimator, - ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective + EntropyEst <: AbstractEntropyEstimator, + ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective ℓπ::Tlogπ - b⁻¹::B - entropy_estimator::EntropyEst - control_variate::ControlVar + b::B + entropy::EntropyEst + cv::ControlVar n_samples::Int - function ADVI(prob, b⁻¹, entropy_estimator, control_variate, n_samples) + function ADVI(prob, n_samples::Int; + entropy::AbstractEntropyEstimator = ClosedFormEntropy(), + cv::Union{<:AbstractControlVariate, Nothing} = nothing, + b = Bijectors.identity) cap = LogDensityProblems.capabilities(prob) if cap === nothing throw( @@ -18,31 +32,16 @@ struct ADVI{Tlogπ, B, ) end ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b⁻¹), typeof(entropy_estimator), typeof(control_variate)}( - ℓπ, b⁻¹, entropy_estimator, control_variate, n_samples - ) + new{typeof(ℓπ), typeof(b), typeof(entropy), typeof(cv)}(ℓπ, b, entropy, cv, n_samples) end end Base.show(io::IO, advi::ADVI) = - print(io, - "ADVI(entropy_estimator=$(advi.entropy_estimator), " * - "control_variate=$(advi.control_variate), " * - "n_samples=$(advi.n_samples))") - -skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) + print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))") -init(advi::ADVI) = init(advi.control_variate) +skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy) -function ADVI(ℓπ, b⁻¹, - entropy_estimator::AbstractEntropyEstimator, - n_samples::Int) - ADVI(ℓπ, b⁻¹, entropy_estimator, nothing, n_samples) -end - -function ADVI(ℓπ, b⁻¹, n_samples::Int) - ADVI(ℓπ, b⁻¹, ClosedFormEntropy(), nothing, n_samples) -end +init(advi::ADVI) = init(advi.cv) function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; rng ::AbstractRNG = default_rng(), @@ -50,10 +49,10 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; ηs ::AbstractMatrix = rand(rng, q_η, n_samples), q_η_entropy::ContinuousMultivariateDistribution = q_η) 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b⁻¹, ηᵢ) + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples end - ℍ = advi.entropy_estimator(q_η_entropy, ηs) + ℍ = advi.entropy(q_η_entropy, ηs) 𝔼ℓ + ℍ end @@ -67,17 +66,17 @@ function estimate_gradient( out::DiffResults.MutableDiffResult) # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing + q_η_stop = skip_entropy_gradient(advi.entropy) ? restructure(λ) : nothing grad!(adbackend, λ, out) do λ′ q_η = restructure(λ′) - q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η + q_η_entropy = skip_entropy_gradient(advi.entropy) ? q_η_stop : q_η -advi(q_η; rng, q_η_entropy) end nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - est_state, stat′ = update(advi.control_variate, est_state) + est_state, stat′ = update(advi.cv, est_state) stat = !isnothing(stat′) ? merge(stat′, stat) : stat out, est_state, stat diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index ddeb64a9..994bdd4f 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -14,27 +14,35 @@ MonteCarloEntropy() = MonteCarloEntropy{false}() Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()") """ - Sticking the Landing Control Variate + StickingTheLandingEntropy() - # Explanation +# Explanation - This eatimator forms a control variate of the form of +The STL estimator forms a control variate of the form of - c(z) = 𝔼-logq(z) + logq(z) = ℍ[q] - logq(z) +```math +\\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) = + \\mathbb{E}\\left[ -\\log q\\left(z\\right) \\right] + + \\log q\\left(z\\right) = \\mathbb{H}\\left(q_{\\lambda}\\right) + \\log q_{\\lambda}\\left(z\\right), +``` +where, for the score term, the gradient is stopped from propagating. - Adding this to the closed-form entropy ELBO estimator yields: - - ELBO - c(z) = 𝔼logπ(z) + ℍ[q] - c(z) = 𝔼logπ(z) - logq(z), - - which has the same expectation, but lower variance when π ≈ q, - and higher variance when π ≉ q. - - # Reference - - Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud. - "Sticking the landing: Simple, lower-variance gradient estimators for - variational inference." - Advances in Neural Information Processing Systems 30 (2017). +Adding this to the closed-form entropy ELBO estimator yields the STL estimator: +```math +\\begin{aligned} + \\widehat{\\mathrm{ELBO}}_{\\mathrm{STL}}\\left(\\lambda\\right) + &\\triangleq \\mathbb{E}\\left[ \\log \\pi \\left(z\\right) \\right] - \\log q_{\\lambda} \\left(z\\right) \\\\ + &= \\mathbb{E}\\left[ \\log \\pi\\left(z\\right) \\right] + + \\mathbb{H}\\left(q_{\\lambda}\\right) - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) \\\\ + &= \\widehat{\\mathrm{ELBO}}\\left(\\lambda\\right) + - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right), +\\end{aligned} +``` +which has the same expectation, but lower variance when ``\\pi \\approx q_{\\lambda}``, +and higher variance when ``\\pi \\not\\approx q_{\\lambda}``. + +# Reference +1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. """ StickingTheLandingEntropy() = MonteCarloEntropy{true}() diff --git a/test/ad.jl b/test/ad.jl index 6b587598..1efa536b 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -5,11 +5,11 @@ using ADTypes @testset "ad" begin @testset "$(adname)" for (adname, adsymbol) ∈ Dict( - :ForwardDiffAuto => AutoForwardDiff(), - :ForwardDiff => AutoForwardDiff(10), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Tracker => AutoTracker(), + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Tracker => AutoTracker(), + :Enzyme => AutoEnzyme(), ) D = 10 A = randn(D, D) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 342b9db1..dadbaf25 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -29,15 +29,15 @@ include("models/normallognormal.jl") :NormalLogNormalFullRank => normallognormal_fullrank, ), (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), - :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), - :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), + :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), + :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, H = StickingTheLandingEntropy()), + :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, H = MonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), + # :ReverseDiff => AutoReverseDiff(), + # :Zygote => AutoZygote(), + # :Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) @@ -68,7 +68,7 @@ include("models/normallognormal.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), + optimizer = Optimisers.Adam(1e-3), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -87,7 +87,7 @@ include("models/normallognormal.jl") rng = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), + optimizer = Optimisers.Adam(1e-3), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -98,7 +98,7 @@ include("models/normallognormal.jl") rng_repl = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), + optimizer = Optimisers.Adam(1e-3), progress = PROGRESS, rng = rng_repl, adbackend = adbackend, From b134f7099062b2c7a6d7b3ec9e30867703c609da Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 02:07:08 +0100 Subject: [PATCH 056/206] fix compile error in advi test --- test/advi_locscale.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index dadbaf25..40e5dace 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -30,8 +30,8 @@ include("models/normallognormal.jl") ), (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), - :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, H = StickingTheLandingEntropy()), - :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, H = MonteCarloEntropy()), + :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), + :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = MonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), From a6ba379b9a97e509076ce0c7e2c2ebd4b6caa737 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 22:33:52 +0100 Subject: [PATCH 057/206] add initial doc --- docs/make.jl | 17 +++++++++++ docs/src/advi.md | 67 ++++++++++++++++++++++++++++++++++++++++++++ docs/src/families.md | 58 ++++++++++++++++++++++++++++++++++++++ docs/src/index.md | 14 +++++++++ 4 files changed, 156 insertions(+) create mode 100644 docs/make.jl create mode 100644 docs/src/advi.md create mode 100644 docs/src/families.md create mode 100644 docs/src/index.md diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 00000000..d2a01d1b --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,17 @@ +#using AdvancedVI +using Documenter + +DocMeta.setdocmeta!( + AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true +) + +makedocs(; + sitename = "AdvancedVI.jl", + modules = [AdvancedVI], + format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), + pages = ["index.md", + "families.md", + "advi.md"], +) + +deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", devbranch="main") diff --git a/docs/src/advi.md b/docs/src/advi.md new file mode 100644 index 00000000..4f4a2eca --- /dev/null +++ b/docs/src/advi.md @@ -0,0 +1,67 @@ + +# [Automatic Differentiation Variational Inference](@id advi) +The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``. +By maximizing ADVI objective, it is equivalent to solving the problem + +```math + \mathrm{minimize}_{\lambda \in \Lambda}\quad \mathrm{KL}\left(q_{\phi,\lambda}, \pi\right). +``` + +The key aspects of the ADVI objective are the followings: +1. The use of the reparameterization gradient estimator +2. Automatically match the support of the target posterior through "bijectors." + +Thanks to Item 2, the user is free to choose any unconstrained variational family, for which +bijectors will automatically match the potentially constrained support of the target. + +In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}`` +from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that +```math +z &\sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad +z &\stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} +``` +ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``. + +That is, + +```math +\begin{aligned} +\mathrm{ADVI}\left(\lambda\right) +&\triangleq +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + \log \pi\left( \phi^{-1}\left( \eta \right) \right) +\right] ++ \mathbb{H}\left(q_{\lambda}\right) ++ \log \lvert J_{\phi^{-1}}\left(\eta\right) \rvert \\ +&= +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + \log \pi\left( \phi^{-1}\left( \eta \right) \right) +\right] ++ +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + - \log q_{\lambda}\left( \eta \right) \lvert J_{\phi}\left(\eta\right) \rvert +\right] \\ +&= +\mathbb{E}_{z \sim q_{\phi,\lambda}}\left[ \log \pi\left(z\right) \right] ++ +\mathbb{H}\left(q_{\phi,\lambda}\right) +\end{aligned} +``` + +The idea of using the reparameterization gradient estimator for variational inference was first +coined by Titsias and Lázaro-Gredilla (2014). +Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by +Fjelde *et al.* (2017). + + +```@docs +ADVI +``` + +# References +1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. +2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. +3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. +4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors. jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR. + + diff --git a/docs/src/families.md b/docs/src/families.md new file mode 100644 index 00000000..f203cf18 --- /dev/null +++ b/docs/src/families.md @@ -0,0 +1,58 @@ + +# [Variational Families](@id families) + +## Location-Scale Variational Family + +### Description +The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as +```math +z = C u + m, +``` +where ``C`` is the *scale* and ``m`` is the location variational parameter. +This family encompases many + + +### Constructors + +```@docs +VILocationScale +``` + +```@docs +VIFullRankGaussian +VIMeanFieldGaussian +``` + +### Examples + +A full-rank variational family can be formed by choosing +```@repl locscale +using AdvancedVI, LinearAlgebra +μ = zeros(2); +L = diagm(ones(2)) |> LowerTriangular; +``` + +A mean-field variational family can be formed by choosing +```@repl locscale +μ = zeros(2); +L = ones(2) |> Diagonal; +``` + +Gaussian variational family: +```@repl locscale +q = VIFullRankGaussian(μ, L) +q = VIMeanFieldGaussian(μ, L) +``` + +Sudent-T Variational Family: + +```@repl locscale +ν = 3 +q = VILocationScale(μ, L, StudentT(ν)) +``` + +Multivariate Laplace family: +```@repl locscale +q = VILocationScale(μ, L, Laplace()) +``` + diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 00000000..be326921 --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,14 @@ +```@meta +CurrentModule = AdvancedVI +``` + +# AdvancedVI + +Documentation for [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl). + +```@index +``` + +```@autodocs +Modules = [AdvancedVI] +``` From 619b1c05eaf669491f82406becb9a31dba1871cc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 22:34:32 +0100 Subject: [PATCH 058/206] remove unused epsilon argument in location scale --- src/distributions/location_scale.jl | 40 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index dc9c1b27..5eb371ad 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,31 +1,31 @@ """ + VILocationScale{L,R,D}(location::L, scale::S, dist::D) <: ContinuousMultivariateDistribution The [location scale] variational family broadly represents various variational families using `location` and `scale` variational parameters. -Multivariate Student-t variational family with ``\\nu``-degrees of freedom can -be constructed as: +It generally represents any distribution for which the sampling path can be +represented as the following: ```julia -q₀ = VILocationScale(μ, L, StudentT(ν), eps(Float32)) + d = length(location) + u = rand(dist, d) + z = scale*u + location ``` - """ -struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution +struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution location::L scale ::S dist ::D - epsilon ::R function VILocationScale(μ::AbstractVector{<:Real}, L::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, - q_base::ContinuousUnivariateDistribution, - epsilon::Real) + q_base::ContinuousUnivariateDistribution) # Restricting all the arguments to have the same types creates problems # with dual-variable-based AD frameworks. @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - new{typeof(μ), typeof(L), typeof(q_base), typeof(epsilon)}(μ, L, q_base, epsilon) + new{typeof(μ), typeof(L), typeof(q_base)}(μ, L, q_base) end end @@ -76,16 +76,22 @@ function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) return x += location end -function VIFullRankGaussian(μ::AbstractVector{T}, - L::AbstractTriangular{T}, - epsilon::Real = eps(T)) where {T <: Real} +""" + VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + +This constructs a multivariate Gaussian distribution with a full rank covariance matrix. +""" +function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) where {T <: Real} q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base, epsilon) + VILocationScale(μ, L, q_base) end -function VIMeanFieldGaussian(μ::AbstractVector{T}, - L::Diagonal{T}, - epsilon::Real = eps(T)) where {T <: Real} +""" + VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + +This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. +""" +function VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T <: Real} q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base, epsilon) + VILocationScale(μ, L, q_base) end From f1c02f02909ff15ac2ddc6276af8589c97cfedf8 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 22:39:16 +0100 Subject: [PATCH 059/206] add project file for documenter --- docs/Project.toml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 docs/Project.toml diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..fc885857 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,7 @@ +[deps] +AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" + +[compat] +Documenter = "0.26" \ No newline at end of file From b0f259a4c32ad293cf0edd236b42b132d7e959b5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 16 Jul 2023 02:55:03 +0100 Subject: [PATCH 060/206] refactor STL gradient calculation to use multiple dispatch --- src/AdvancedVI.jl | 6 +- src/distributions/location_scale.jl | 16 ++--- src/objectives/elbo/advi.jl | 97 +++++++++++++++++++++++------ src/objectives/elbo/entropy.jl | 11 +--- test/advi_locscale.jl | 6 +- test/models/normal.jl | 51 +++++++++++++++ test/models/normallognormal.jl | 4 +- test/models/utils.jl | 8 +++ 8 files changed, 160 insertions(+), 39 deletions(-) create mode 100644 test/models/normal.jl create mode 100644 test/models/utils.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e3dd85a8..9f93885c 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -73,8 +73,9 @@ abstract type AbstractControlVariate end function update end update(::Nothing, ::Nothing) = (nothing, nothing) -include("objectives/elbo/advi.jl") +# entropy.jl must preceed advi.jl include("objectives/elbo/entropy.jl") +include("objectives/elbo/advi.jl") export ELBO, @@ -82,13 +83,14 @@ export ADVIEnergy, ClosedFormEntropy, StickingTheLandingEntropy, - MonteCarloEntropy + FullMonteCarloEntropy # Variational Families include("distributions/location_scale.jl") export + VILocationScale, VIFullRankGaussian, VIMeanFieldGaussian diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 5eb371ad..e901e8de 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,8 +1,8 @@ """ - VILocationScale{L,R,D}(location::L, scale::S, dist::D) <: ContinuousMultivariateDistribution + VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution -The [location scale] variational family broadly represents various variational +The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. It generally represents any distribution for which the sampling path can be @@ -18,14 +18,14 @@ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution scale ::S dist ::D - function VILocationScale(μ::AbstractVector{<:Real}, - L::Union{<:AbstractTriangular{<:Real}, + function VILocationScale(location::AbstractVector{<:Real}, + scale::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, - q_base::ContinuousUnivariateDistribution) + dist::ContinuousUnivariateDistribution) # Restricting all the arguments to have the same types creates problems # with dual-variable-based AD frameworks. - @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - new{typeof(μ), typeof(L), typeof(q_base)}(μ, L, q_base) + @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2)) + new{typeof(location), typeof(scale), typeof(dist)}(location, scale, dist) end end @@ -87,7 +87,7 @@ function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) whe end """ - VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. """ diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index e965ea73..e4e93327 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,9 +1,23 @@ """ - ADVI + ADVI( + prob, + n_samples::Int; + entropy::AbstractEntropyEstimator = ClosedFormEntropy(), + cv::Union{<:AbstractControlVariate, Nothing} = nothing, + b = Bijectors.identity + ) Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. +# Arguments +- `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface. + - `logdensity` must be differentiable by the selected AD backend. +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. +- `entropy`: The estimator for the entropy term. +- `cv`: A control variate +- `b`: A bijector mapping the support of the base distribution to that of `prob`. + # Requirements - ``q_{\\lambda}`` implements `rand`. - ``\\pi`` must be differentiable @@ -39,40 +53,87 @@ end Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))") -skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy) - init(advi::ADVI) = init(advi.cv) -function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; - rng ::AbstractRNG = default_rng(), - n_samples ::Int = advi.n_samples, - ηs ::AbstractMatrix = rand(rng, q_η, n_samples), - q_η_entropy::ContinuousMultivariateDistribution = q_η) +function (advi::ADVI)( + rng::AbstractRNG, + q_η::ContinuousMultivariateDistribution, + ηs ::AbstractMatrix +) + n_samples = size(ηs, 2) 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples end - ℍ = advi.entropy(q_η_entropy, ηs) + ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ end -function estimate_gradient( +""" + (advi::ADVI)( + q_η::ContinuousMultivariateDistribution; + rng::AbstractRNG = Random.default_rng(), + n_samples::Int = advi.n_samples + ) + +Evaluate the ELBO using the ADVI formulation. + +# Arguments +- `q_η`: Variational approximation before applying a bijector (unconstrained support). +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. + +""" +function (advi::ADVI)( + q_η::ContinuousMultivariateDistribution; + rng::AbstractRNG = default_rng(), + n_samples::Int = advi.n_samples +) + ηs = rand(rng, q_η, n_samples) + advi(rng, q_η, ηs) +end + +function estimate_advi_gradient_maybe_stl!( rng::AbstractRNG, adbackend::AbstractADType, - advi::ADVI, - est_state, + advi::ADVI{P, B, StickingTheLandingEntropy, CV}, λ::Vector{<:Real}, restructure, - out::DiffResults.MutableDiffResult) - - # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(advi.entropy) ? restructure(λ) : nothing + out::DiffResults.MutableDiffResult +) where {P, B, CV} + q_η_stop = restructure(λ) + grad!(adbackend, λ, out) do λ′ + q_η = restructure(λ′) + ηs = rand(rng, q_η, advi.n_samples) + -advi(rng, q_η_stop, ηs) + end +end +function estimate_advi_gradient_maybe_stl!( + rng::AbstractRNG, + adbackend::AbstractADType, + advi::ADVI{P, B, <:Union{ClosedFormEntropy, FullMonteCarloEntropy}, CV}, + λ::Vector{<:Real}, + restructure, + out::DiffResults.MutableDiffResult +) where {P, B, CV} grad!(adbackend, λ, out) do λ′ q_η = restructure(λ′) - q_η_entropy = skip_entropy_gradient(advi.entropy) ? q_η_stop : q_η - -advi(q_η; rng, q_η_entropy) + ηs = rand(rng, q_η, advi.n_samples) + -advi(rng, q_η, ηs) end +end + +function estimate_gradient( + rng::AbstractRNG, + adbackend::AbstractADType, + advi::ADVI, + est_state, + λ::Vector{<:Real}, + restructure, + out::DiffResults.MutableDiffResult +) + estimate_advi_gradient_maybe_stl!( + rng, adbackend, advi, λ, restructure, out) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 994bdd4f..7f37b619 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -7,11 +7,9 @@ end skip_entropy_gradient(::ClosedFormEntropy) = false -struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end +abstract type MonteCarloEntropy <: AbstractEntropyEstimator end -MonteCarloEntropy() = MonteCarloEntropy{false}() - -Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()") +struct FullMonteCarloEntropy <: MonteCarloEntropy end """ StickingTheLandingEntropy() @@ -44,11 +42,8 @@ and higher variance when ``\\pi \\not\\approx q_{\\lambda}``. # Reference 1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. """ -StickingTheLandingEntropy() = MonteCarloEntropy{true}() - -skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding -Base.show(io::IO, entropy::MonteCarloEntropy{true}) = print(io, "StickingTheLandingEntropy()") +struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 40e5dace..2f19ca61 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -19,6 +19,8 @@ struct TestModel{M,L,S} end include("models/normallognormal.jl") +include("models/normal.jl") +include("models/utils.jl") @testset "advi" begin @testset "locscale" begin @@ -27,11 +29,13 @@ include("models/normallognormal.jl") (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, + :NormalMeanField => normal_meanfield, + :NormalFullRank => normal_fullrank, ), (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), - :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = MonteCarloEntropy()), + :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = FullMonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/models/normal.jl b/test/models/normal.jl new file mode 100644 index 00000000..a677af93 --- /dev/null +++ b/test/models/normal.jl @@ -0,0 +1,51 @@ + +struct TestMvNormal{M,S} + μ::M + Σ::S +end + +function LogDensityProblems.logdensity(model::TestMvNormal, θ) + @unpack μ, Σ = model + logpdf(MvNormal(μ, Σ), θ) +end + +function LogDensityProblems.dimension(model::TestMvNormal) + length(model.μ) +end + +function LogDensityProblems.capabilities(::Type{<:TestMvNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::TestMvNormal) + identity +end + +function normal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + L₀ = sample_cholesky(rng, n_dims) + ϵ = eps(realtype)*10 + Σ = (L₀*L₀' + ϵ*I) |> Hermitian + + Σ_chol = cholesky(Σ) + model = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol)) + + L = Σ_chol.L |> LowerTriangular + + TestModel(model, μ, L, n_dims, false) +end + +function normal_meanfield(realtype; rng = default_rng()) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = TestMvNormal(μ, PDMats.PDiagMat(σ)) + + L = σ |> Diagonal + + TestModel(model, μ, L, n_dims, true) +end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 18e8b4a3..ca8c9a4d 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,8 +32,8 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(n_dims*2) + L₀_y = sample_cholesky(rng, n_dims) + ϵ = eps(realtype)*10 Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) diff --git a/test/models/utils.jl b/test/models/utils.jl new file mode 100644 index 00000000..c1a9a407 --- /dev/null +++ b/test/models/utils.jl @@ -0,0 +1,8 @@ + +function sample_cholesky(rng::AbstractRNG, n_dims::Int) + A = randn(rng, n_dims, n_dims) + L = tril(A) + idx = diagind(L) + @. L[idx] = log(exp(L[idx]) + 1) + L |> LowerTriangular +end From b72c2585a1d3e461d9903884d16d9b019c11e828 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 16 Jul 2023 03:49:08 +0100 Subject: [PATCH 061/206] fix type bugs, relax test threshold for the exact inference tests --- test/advi_locscale.jl | 8 ++++---- test/models/normal.jl | 5 ++--- test/models/normallognormal.jl | 5 ++--- test/models/utils.jl | 4 ++-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 2f19ca61..1552be5e 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -72,7 +72,7 @@ include("models/utils.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-3), + optimizer = Optimisers.Adam(1e-2), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -82,7 +82,7 @@ include("models/utils.jl") L = q.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ Δλ₀/√T + @test Δλ ≤ Δλ₀/T^(1/4) @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -91,7 +91,7 @@ include("models/utils.jl") rng = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-3), + optimizer = Optimisers.Adam(realtype(1e-2)), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -102,7 +102,7 @@ include("models/utils.jl") rng_repl = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-3), + optimizer = Optimisers.Adam(realtype(1e-2)), progress = PROGRESS, rng = rng_repl, adbackend = adbackend, diff --git a/test/models/normal.jl b/test/models/normal.jl index a677af93..f60ad5f3 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -25,9 +25,8 @@ function normal_fullrank(realtype; rng = default_rng()) n_dims = 5 μ = randn(rng, realtype, n_dims) - L₀ = sample_cholesky(rng, n_dims) - ϵ = eps(realtype)*10 - Σ = (L₀*L₀' + ϵ*I) |> Hermitian + L₀ = sample_cholesky(rng, realtype, n_dims) + Σ = L₀*L₀' |> Hermitian Σ_chol = cholesky(Σ) model = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol)) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index ca8c9a4d..cab73cce 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,9 +32,8 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L₀_y = sample_cholesky(rng, n_dims) - ϵ = eps(realtype)*10 - Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian + L₀_y = sample_cholesky(rng, realtype, n_dims) + Σ_y = L₀_y*L₀_y' |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) diff --git a/test/models/utils.jl b/test/models/utils.jl index c1a9a407..3d483c46 100644 --- a/test/models/utils.jl +++ b/test/models/utils.jl @@ -1,6 +1,6 @@ -function sample_cholesky(rng::AbstractRNG, n_dims::Int) - A = randn(rng, n_dims, n_dims) +function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) + A = randn(rng, type, n_dims, n_dims) L = tril(A) idx = diagind(L) @. L[idx] = log(exp(L[idx]) + 1) From a8df9eb8b635e9805e3f307b7b5b64ccb4f1f970 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:33:15 +0100 Subject: [PATCH 062/206] refactor derivative utils to match NormalizingFlows.jl with extras --- Project.toml | 20 +++++++-- ext/AdvancedVIEnzymeExt.jl | 26 +++++++++++ ext/AdvancedVIForwardDiffExt.jl | 29 ++++++++++++ ext/AdvancedVIReverseDiffExt.jl | 23 ++++++++++ ext/AdvancedVIZygoteExt.jl | 24 ++++++++++ src/AdvancedVI.jl | 79 ++++++++++++++++++--------------- src/compat/enzyme.jl | 16 ------- src/compat/reversediff.jl | 19 -------- src/compat/zygote.jl | 13 ------ src/grad.jl | 30 ------------- src/objectives/elbo/advi.jl | 6 ++- test/ad.jl | 7 ++- 12 files changed, 167 insertions(+), 125 deletions(-) create mode 100644 ext/AdvancedVIEnzymeExt.jl create mode 100644 ext/AdvancedVIForwardDiffExt.jl create mode 100644 ext/AdvancedVIReverseDiffExt.jl create mode 100644 ext/AdvancedVIZygoteExt.jl delete mode 100644 src/compat/enzyme.jl delete mode 100644 src/compat/reversediff.jl delete mode 100644 src/compat/zygote.jl delete mode 100644 src/grad.jl diff --git a/Project.toml b/Project.toml index cf698f7a..ab00d674 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,10 @@ version = "0.2.4" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @@ -21,24 +21,36 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" +DiffResults = "1.0.3" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" -ForwardDiff = "0.10.3" +ForwardDiff = "0.10.25" +LogDensityProblems = "2.1.1" +Optimisers = "0.2.16" ProgressMeter = "1.0.0" Requires = "0.5, 1.0" +ReverseDiff = "1.14" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -Tracker = "0.2.3" julia = "1.6" [extras] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Pkg", "Test"] diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl new file mode 100644 index 00000000..8333299f --- /dev/null +++ b/ext/AdvancedVIEnzymeExt.jl @@ -0,0 +1,26 @@ + +module AdvancedVIEnzymeExt + +if isdefined(Base, :get_extension) + using Enzyme + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..Enzyme + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + y = f(θ) + DiffResults.value!(out, y) + ∇θ = DiffResults.gradient(out) + fill!(∇θ, zero(T)) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) + return out +end + +end diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl new file mode 100644 index 00000000..e6b03af2 --- /dev/null +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -0,0 +1,29 @@ + +module AdvancedVIForwardDiffExt + +if isdefined(Base, :get_extension) + using ForwardDiff + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..ForwardDiff + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +# extract chunk size from AutoForwardDiff +getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + chunk_size = getchunksize(ad) + config = if isnothing(chunk_size) + ForwardDiff.GradientConfig(f, θ) + else + ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + end + ForwardDiff.gradient!(out, f, θ, config) + return out +end + +end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl new file mode 100644 index 00000000..fd7fbaab --- /dev/null +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -0,0 +1,23 @@ + +module AdvancedVIReverseDiffExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using ReverseDiff +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..ReverseDiff +end + +# ReverseDiff without compiled tape +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + tp = ReverseDiff.GradientTape(f, θ) + ReverseDiff.gradient!(out, tp, θ) + return out +end + +end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl new file mode 100644 index 00000000..b447d071 --- /dev/null +++ b/ext/AdvancedVIZygoteExt.jl @@ -0,0 +1,24 @@ + +module AdvancedVIZygoteExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using Zygote +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..Zygote +end + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + y, back = Zygote.pullback(f, θ) + ∇θ = back(one(T)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, first(∇θ)) + return out +end + +end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9f93885c..697f3c83 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -21,9 +21,9 @@ using LinearAlgebra: AbstractTriangular using LogDensityProblems -using ADTypes +using ADTypes, DiffResults using ADTypes: AbstractADType -using ForwardDiff, Tracker + using FillArrays using PDMats @@ -34,29 +34,23 @@ using StatsBase: entropy const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) -using Requires -function __init__() - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("compat/enzyme.jl") - end -end - +# derivatives """ - grad!(f, λ, out) - -Computes the gradients of the objective f. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. + value_and_gradient!( + ad::ADTypes.AbstractADType, + f, + θ::AbstractVector{T}, + out::DiffResults.MutableDiffResult + ) where {T<:Real} + +Compute the value and gradient of a function `f` at `θ` using the automatic +differentiation backend `ad`. The result is stored in `out`. +The function `f` must return a scalar value. The gradient is stored in `out` as a +vector of the same length as `θ`. """ -function grad! end +function value_and_gradient! end -include("grad.jl") +export value_and_gradient! # estimators abstract type AbstractVariationalObjective end @@ -94,21 +88,8 @@ export VIFullRankGaussian, VIMeanFieldGaussian -""" - optimize(model, alg::VariationalInference) - optimize(model, alg::VariationalInference, q::VariationalPosterior) - optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) - -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. - -# Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior -""" +# Optimization Routine + function optimize end include("optimize.jl") @@ -117,4 +98,28 @@ export optimize include("utils.jl") + +# optional dependencies +if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base + using Requires +end + +using Requires +function __init__() + @static if !isdefined(Base, :get_extension) + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIZygoteExt.jl") + end + @require ForwardDiff = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIForwardDiffExt.jl") + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/AdvancedVIReverseDiffExt.jl") + end + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/AdvancedVIEnzymeExt.jl") + end + end +end end # module + diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl deleted file mode 100644 index cab50862..00000000 --- a/src/compat/enzyme.jl +++ /dev/null @@ -1,16 +0,0 @@ - -function AdvancedVI.grad!( - f::Function, - ::AutoEnzyme, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(λ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) - return out -end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl deleted file mode 100644 index 4d8f87d8..00000000 --- a/src/compat/reversediff.jl +++ /dev/null @@ -1,19 +0,0 @@ -using .ReverseDiff: compile, GradientTape -using .ReverseDiff.DiffResults: GradientResult - -tape(f, x) = GradientTape(f, x) -function taperesult(f, x) - return tape(f, x), GradientResult(x) -end - -# Precompiled tapes are not properly supported yet. -function AdvancedVI.grad!( - f::Function, - ::AutoReverseDiff, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - tp = tape(f, λ) - ReverseDiff.gradient!(out, tp, λ) - return out -end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl deleted file mode 100644 index f1a29b87..00000000 --- a/src/compat/zygote.jl +++ /dev/null @@ -1,13 +0,0 @@ - -function AdvancedVI.grad!( - f::Function, - ::AutoZygote, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - y, back = Zygote.pullback(f, λ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out -end diff --git a/src/grad.jl b/src/grad.jl deleted file mode 100644 index e68e1623..00000000 --- a/src/grad.jl +++ /dev/null @@ -1,30 +0,0 @@ - -# default implementations -function grad!( - f::Function, - adtype::AutoForwardDiff{chunksize}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) where {chunksize} - # Set chunk size and do ForwardMode. - config = if isnothing(chunksize) - ForwardDiff.GradientConfig(f, λ) - else - ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunksize)) - end - ForwardDiff.gradient!(out, f, λ, config) -end - -function grad!( - f::Function, - ::AutoTracker, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - λ_tracked = Tracker.param(λ) - y = f(λ_tracked) - Tracker.back!(y, 1.0) - - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(λ_tracked)) -end diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index e4e93327..d308db0a 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -101,11 +101,12 @@ function estimate_advi_gradient_maybe_stl!( out::DiffResults.MutableDiffResult ) where {P, B, CV} q_η_stop = restructure(λ) - grad!(adbackend, λ, out) do λ′ + f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η_stop, ηs) end + grad!(adbackend, f, λ, out) end function estimate_advi_gradient_maybe_stl!( @@ -116,11 +117,12 @@ function estimate_advi_gradient_maybe_stl!( restructure, out::DiffResults.MutableDiffResult ) where {P, B, CV} - grad!(adbackend, λ, out) do λ′ + f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η, ηs) end + value_and_gradient!(adbackend, f, λ, out) end function estimate_gradient( diff --git a/test/ad.jl b/test/ad.jl index 1efa536b..9df26d9f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -9,15 +9,14 @@ using ADTypes :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Tracker => AutoTracker(), - :Enzyme => AutoEnzyme(), + # :Enzyme => AutoEnzyme(), # Currently not tested against. ) D = 10 A = randn(D, D) λ = randn(D) grad_buf = DiffResults.GradientResult(λ) - AdvancedVI.grad!(adsymbol, λ, grad_buf) do λ′ - λ′'*A*λ′ / 2 - end + f(λ′) = λ′'*A*λ′ / 2 + AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A')*λ/2 From e8db6a7ac62d1916969aaaeea336677ba19eafa0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:34:14 +0100 Subject: [PATCH 063/206] add documentation, refactor optimize --- docs/Project.toml | 2 +- docs/make.jl | 11 ++-- docs/src/advi.md | 36 ++++++++++- docs/src/families.md | 32 ++++++---- src/objectives/elbo/entropy.jl | 32 ---------- src/optimize.jl | 106 +++++++++++++++++++++------------ 6 files changed, 130 insertions(+), 89 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index fc885857..c625d07f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] -Documenter = "0.26" \ No newline at end of file +Documenter = "0.26, 0.27" \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl index d2a01d1b..b9a8eb5f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,5 @@ -#using AdvancedVI + +using AdvancedVI using Documenter DocMeta.setdocmeta!( @@ -9,9 +10,9 @@ makedocs(; sitename = "AdvancedVI.jl", modules = [AdvancedVI], format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["index.md", - "families.md", - "advi.md"], + pages = ["Home" => "index.md", + "Families" => "families.md", + "ADVI" => "advi.md"], ) -deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", devbranch="main") +deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) diff --git a/docs/src/advi.md b/docs/src/advi.md index 4f4a2eca..0597e03c 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -1,5 +1,8 @@ # [Automatic Differentiation Variational Inference](@id advi) + +# Introduction + The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``. By maximizing ADVI objective, it is equivalent to solving the problem @@ -17,8 +20,8 @@ bijectors will automatically match the potentially constrained support of the ta In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}`` from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that ```math -z &\sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad -z &\stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} +z \sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} ``` ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``. @@ -53,15 +56,44 @@ coined by Titsias and Lázaro-Gredilla (2014). Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by Fjelde *et al.* (2017). +# The `ADVI` Objective ```@docs ADVI ``` +# The "Sticking the Landing" Control Variate +The STL control variate was proposed by Roeder *et al.* (2017). +By slightly modifying the differentiation path, it implicitly forms a control variate of the form of +```math +\mathrm{CV}_{\mathrm{STL}}\left(z\right) \triangleq \mathbb{H}\left(q_{\lambda}\right) + \log q_{\lambda}\left(z\right), +``` +which has a mean of zero. + +Adding this to the closed-form entropy ELBO estimator yields the STL estimator: +```math +\begin{aligned} + \widehat{\mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) + &\triangleq \mathbb{E}\left[ \log \pi \left(z\right) \right] - \log q_{\lambda} \left(z\right) \\ + &= \mathbb{E}\left[ \log \pi\left(z\right) \right] + + \mathbb{H}\left(q_{\lambda}\right) - \mathrm{CV}_{\mathrm{STL}}\left(z\right) \\ + &= \widehat{\mathrm{ELBO}}\left(\lambda\right) + - \mathrm{CV}_{\mathrm{STL}}\left(z\right), +\end{aligned} +``` +which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. +The conditions for which the STL estimator results in lower variance is still an active subject for research. + +The STL control variate can be used by changing the entropy estimator as follows: +```julia +ADVI(prob, n_samples; entropy = StickingTheLanding(), b = bijector) +``` + # References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. 3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. 4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors. jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR. +5. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. diff --git a/docs/src/families.md b/docs/src/families.md index f203cf18..d326ce7a 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -1,5 +1,5 @@ -# [Variational Families](@id families) +# Variational Families ## Location-Scale Variational Family @@ -25,34 +25,42 @@ VIMeanFieldGaussian ### Examples -A full-rank variational family can be formed by choosing ```@repl locscale -using AdvancedVI, LinearAlgebra +using AdvancedVI, LinearAlgebra, Distributions; μ = zeros(2); -L = diagm(ones(2)) |> LowerTriangular; -``` - -A mean-field variational family can be formed by choosing -```@repl locscale -μ = zeros(2); -L = ones(2) |> Diagonal; ``` Gaussian variational family: ```@repl locscale +L = diagm(ones(2)) |> LowerTriangular; q = VIFullRankGaussian(μ, L) + +L = ones(2) |> Diagonal; q = VIMeanFieldGaussian(μ, L) ``` Sudent-T Variational Family: ```@repl locscale -ν = 3 -q = VILocationScale(μ, L, StudentT(ν)) +ν = 3; + +# Full-Rank +L = diagm(ones(2)) |> LowerTriangular; +q = VILocationScale(μ, L, TDist(ν)) + +# Mean-Field +L = ones(2) |> Diagonal; +q = VILocationScale(μ, L, TDist(ν)) ``` Multivariate Laplace family: ```@repl locscale +# Full-Rank +L = diagm(ones(2)) |> LowerTriangular; +q = VILocationScale(μ, L, Laplace()) + +# Mean-Field +L = ones(2) |> Diagonal; q = VILocationScale(μ, L, Laplace()) ``` diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 7f37b619..e9f180f5 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -11,38 +11,6 @@ abstract type MonteCarloEntropy <: AbstractEntropyEstimator end struct FullMonteCarloEntropy <: MonteCarloEntropy end -""" - StickingTheLandingEntropy() - -# Explanation - -The STL estimator forms a control variate of the form of - -```math -\\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) = - \\mathbb{E}\\left[ -\\log q\\left(z\\right) \\right] - + \\log q\\left(z\\right) = \\mathbb{H}\\left(q_{\\lambda}\\right) + \\log q_{\\lambda}\\left(z\\right), -``` -where, for the score term, the gradient is stopped from propagating. - -Adding this to the closed-form entropy ELBO estimator yields the STL estimator: -```math -\\begin{aligned} - \\widehat{\\mathrm{ELBO}}_{\\mathrm{STL}}\\left(\\lambda\\right) - &\\triangleq \\mathbb{E}\\left[ \\log \\pi \\left(z\\right) \\right] - \\log q_{\\lambda} \\left(z\\right) \\\\ - &= \\mathbb{E}\\left[ \\log \\pi\\left(z\\right) \\right] - + \\mathbb{H}\\left(q_{\\lambda}\\right) - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) \\\\ - &= \\widehat{\\mathrm{ELBO}}\\left(\\lambda\\right) - - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right), -\\end{aligned} -``` -which has the same expectation, but lower variance when ``\\pi \\approx q_{\\lambda}``, -and higher variance when ``\\pi \\not\\approx q_{\\lambda}``. - -# Reference -1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. -""" - struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) diff --git a/src/optimize.jl b/src/optimize.jl index 8b36df04..ef16dcce 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -4,73 +4,105 @@ function pm_next!(pm, stats::NamedTuple) end """ - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) + optimize( + objective ::AbstractVariationalObjective, + restructure, + λ₀ ::AbstractVector{<:Real}, + n_max_iter ::Int; + kwargs... + ) -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. +Optimize the variational objective `objective` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`. + + optimize( + objective ::AbstractVariationalObjective, + q, + n_max_iter::Int; + kwargs... + ) + +Optimize the variational objective `objective` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface. + +# Arguments +- `objective`: Variational Objective. +- `λ₀`: Initial value of the variational parameters. +- `restructure`: Function that reconstructs the variational approximation from the flattened parameters. +- `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. +- `n_max_iter`: Maximum number of iterations. + +# Keyword Arguments +- `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) +- `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) +- `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) +- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stats, restructure, λ)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) +- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) + +# Returns +- `λ`: Variational parameters optimizing the variational objective. +- `stats`: Statistics gathered during inference. +- `opt_state`: Final state of the optimiser. """ function optimize( - objective ::AbstractVariationalObjective, + objective ::AbstractVariationalObjective, restructure, - λ ::AbstractVector{<:Real}, - n_max_iter::Int; - optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), - rng ::AbstractRNG = default_rng(), - progress ::Bool = true, - callback! = nothing, - terminate = (args...) -> false, - adbackend::AbstractADType = AutoForwardDiff(), + λ₀ ::AbstractVector{<:Real}, + n_max_iter ::Int; + optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), + rng ::AbstractRNG = default_rng(), + show_progress::Bool = true, + callback! = nothing, + #convergence = (args...) -> (false, con_state), + adbackend::AbstractADType = AutoForwardDiff(), + prog = ProgressMeter.Progress( + n_max_iter; + desc = "Optimizing", + barlen = 31, + showspeed = true, + enabled = show_progress + ) ) - opt_state = Optimisers.init(optimizer, λ) + λ = copy(λ₀) + opt_state = Optimisers.setup(optimizer, λ) est_state = init(objective) + #con_state = init(convergence) grad_buf = DiffResults.GradientResult(λ) - - prog = ProgressMeter.Progress(n_max_iter; - barlen = 0, - enabled = progress, - showspeed = true) - stats = Vector{NamedTuple}(undef, n_max_iter) + stats = NamedTuple[] for t = 1:n_max_iter stat = (iteration=t,) grad_buf, est_state, stat′ = estimate_gradient( rng, adbackend, objective, est_state, λ, restructure, grad_buf) - g = DiffResults.gradient(grad_buf) stat = merge(stat, stat′) - opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g) - Optimisers.subtract!(λ, Δλ) - - stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) + g = DiffResults.gradient(grad_buf) + opt_state, λ = Optimisers.update!(opt_state, λ, g) + stat′ = (iteration=t, gradient_norm=norm(g)) stat = merge(stat, stat′) - q = restructure(λ) - if !isnothing(callback!) - stat′ = callback!(q, stat) + stat′ = callback!(; est_state, stat, restructure, λ) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end AdvancedVI.DEBUG && @debug "Step $t" stat... pm_next!(prog, stat) - stats[t] = stat + push!(stats, stat) - # Termination decision is work in progress - if terminate(rng, λ, q, objective, stat) - stats = stats[1:t] - break - end + #convergence(rng, t, restructure, λ, q, objective, stat) + #if terminate() + # break + #end end - λ, stats + λ, map(identity, stats), opt_state end -function optimize(objective::AbstractVariationalObjective, - q, +function optimize(objective ::AbstractVariationalObjective, + q₀, n_max_iter::Int; kwargs...) - λ, restructure = Optimisers.destructure(q) + λ, restructure = Optimisers.destructure(q₀) λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...) restructure(λ), stats end From 65a2b37d354798dd40161dc897f7124b5b68b857 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:49:57 +0100 Subject: [PATCH 064/206] fix bug missing extension --- Project.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Project.toml b/Project.toml index ab00d674..ffc41a4b 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,12 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[extensions] +AdvancedVIEnzymeExt = "Enzyme" +AdvancedVIForwardDiffExt = "ForwardDiff" +AdvancedVIReverseDiffExt = "ReverseDiff" +AdvancedVIZygoteExt = "Zygote" + [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" From 1a02051f6fb8e2c59b39e7faa58c91db7ca589b3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:50:24 +0100 Subject: [PATCH 065/206] remove tracker from tests --- test/ad.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 9df26d9f..2c4f802a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,6 @@ using ReTest -using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote +using ForwardDiff, ReverseDiff, Enzyme, Zygote using ADTypes @testset "ad" begin @@ -8,7 +8,6 @@ using ADTypes :ForwardDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Tracker => AutoTracker(), # :Enzyme => AutoEnzyme(), # Currently not tested against. ) D = 10 From d8b5ea5a153e5a484972c8c46e98a58e0b958b95 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:50:50 +0100 Subject: [PATCH 066/206] remove export for internal derivative utils --- src/AdvancedVI.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 697f3c83..a1cf360a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -50,8 +50,6 @@ vector of the same length as `θ`. """ function value_and_gradient! end -export value_and_gradient! - # estimators abstract type AbstractVariationalObjective end @@ -104,11 +102,10 @@ if !isdefined(Base, :get_extension) # check whether :get_extension is defined in using Requires end -using Requires function __init__() @static if !isdefined(Base, :get_extension) - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/AdvancedVIZygoteExt.jl") + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/AdvancedVIEnzymeExt.jl") end @require ForwardDiff = "e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/AdvancedVIForwardDiffExt.jl") @@ -116,10 +113,11 @@ function __init__() @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/AdvancedVIReverseDiffExt.jl") end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("../ext/AdvancedVIEnzymeExt.jl") + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIZygoteExt.jl") end end end -end # module + +end From 818bc2c33fb7513681c06bb6a99cf341c97957dc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:28:47 +0100 Subject: [PATCH 067/206] fix test errors, old interface --- src/optimize.jl | 4 ++-- test/advi_locscale.jl | 36 ++++++++++++++++++------------------ test/runtests.jl | 4 +++- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index ef16dcce..7c876b39 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -103,6 +103,6 @@ function optimize(objective ::AbstractVariationalObjective, n_max_iter::Int; kwargs...) λ, restructure = Optimisers.destructure(q₀) - λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...) - restructure(λ), stats + λ, stats, opt_state = optimize(objective, restructure, λ, n_max_iter; kwargs...) + restructure(λ), stats, opt_state end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 1552be5e..d4ef7aec 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -69,13 +69,13 @@ include("models/utils.jl") obj = objective(model, b⁻¹, 10) @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats = optimize( + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-2), - progress = PROGRESS, - rng = rng, - adbackend = adbackend, + optimizer = Optimisers.Adam(1e-2), + show_progress = PROGRESS, + rng = rng, + adbackend = adbackend, ) μ = q.location @@ -88,24 +88,24 @@ include("models/utils.jl") end @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats = optimize( + rng = Philox4x(UInt64, seed, 8) + q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), - progress = PROGRESS, - rng = rng, - adbackend = adbackend, + optimizer = Optimisers.Adam(realtype(1e-2)), + show_progress = PROGRESS, + rng = rng, + adbackend = adbackend, ) μ = q.location L = q.scale - rng_repl = Philox4x(UInt64, seed, 8) - q, stats = optimize( + rng_repl = Philox4x(UInt64, seed, 8) + q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), - progress = PROGRESS, - rng = rng_repl, - adbackend = adbackend, + optimizer = Optimisers.Adam(realtype(1e-2)), + show_progress = PROGRESS, + rng = rng_repl, + adbackend = adbackend, ) μ_repl = q.location L_repl = q.scale diff --git a/test/runtests.jl b/test/runtests.jl index ddc1d09c..68225fd9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ -using Comonicon +using ReTest using ReTest: @testset, @test + +using Comonicon using Random using Random123 using Statistics From 215abf34639e76b59d3d8b7ad1b64d24ec7500e0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:29:06 +0100 Subject: [PATCH 068/206] fix wrong derivative interface, add documentation --- src/objectives/elbo/advi.jl | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index d308db0a..8bc14bc9 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,26 +1,21 @@ """ - ADVI( - prob, - n_samples::Int; - entropy::AbstractEntropyEstimator = ClosedFormEntropy(), - cv::Union{<:AbstractControlVariate, Nothing} = nothing, - b = Bijectors.identity - ) + ADVI(prob, n_samples; kwargs...) Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. # Arguments - `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface. - - `logdensity` must be differentiable by the selected AD backend. -- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. -- `entropy`: The estimator for the entropy term. -- `cv`: A control variate -- `b`: A bijector mapping the support of the base distribution to that of `prob`. +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.) + +# Keyword Arguments +- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) +- `cv`: A control variate. +- `b`: A bijector mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) # Requirements - ``q_{\\lambda}`` implements `rand`. -- ``\\pi`` must be differentiable +- `logdensity(prob)` must be differentiable by the selected AD backend. Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. """ @@ -106,7 +101,7 @@ function estimate_advi_gradient_maybe_stl!( ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η_stop, ηs) end - grad!(adbackend, f, λ, out) + value_and_gradient!(adbackend, f, λ, out) end function estimate_advi_gradient_maybe_stl!( From 88ad7680a928932be97e1f075d5cd1c0d497a651 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:29:25 +0100 Subject: [PATCH 069/206] update documentation --- docs/src/advi.md | 17 ++++++++----- docs/src/families.md | 44 +++++++++++++++++++++------------- src/objectives/elbo/entropy.jl | 9 +++++++ 3 files changed, 48 insertions(+), 22 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 0597e03c..37b3541b 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -1,7 +1,7 @@ # [Automatic Differentiation Variational Inference](@id advi) -# Introduction +## Introduction The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``. By maximizing ADVI objective, it is equivalent to solving the problem @@ -56,13 +56,13 @@ coined by Titsias and Lázaro-Gredilla (2014). Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by Fjelde *et al.* (2017). -# The `ADVI` Objective +## The `ADVI` Objective ```@docs ADVI ``` -# The "Sticking the Landing" Control Variate +## The `StickingTheLanding` Control Variate The STL control variate was proposed by Roeder *et al.* (2017). By slightly modifying the differentiation path, it implicitly forms a control variate of the form of ```math @@ -84,12 +84,17 @@ Adding this to the closed-form entropy ELBO estimator yields the STL estimator: which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. The conditions for which the STL estimator results in lower variance is still an active subject for research. -The STL control variate can be used by changing the entropy estimator as follows: +The STL control variate can be used by changing the entropy estimator using the following object: +```@docs +StickingTheLandingEntropy +``` + +For example: ```julia -ADVI(prob, n_samples; entropy = StickingTheLanding(), b = bijector) +ADVI(prob, n_samples; entropy = StickingTheLandingEntropy(), b = bijector) ``` -# References +## References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. 3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. diff --git a/docs/src/families.md b/docs/src/families.md index d326ce7a..e6eaa91b 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -1,18 +1,26 @@ -# Variational Families +# Location-Scale Variational Family -## Location-Scale Variational Family - -### Description +## Description The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as ```math -z = C u + m, +z \sim q_{\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} z = C u + m;\quad u \sim \varphi ``` -where ``C`` is the *scale* and ``m`` is the location variational parameter. -This family encompases many - +where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. +``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. +The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. +The probability density is given by +```math + q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)) +``` +and the entropy is given as +```math + \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|, +``` +where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. -### Constructors +## Constructors ```@docs VILocationScale @@ -23,15 +31,13 @@ VIFullRankGaussian VIMeanFieldGaussian ``` -### Examples +## Gaussian Variational Families -```@repl locscale +Gaussian variational family: +```julia using AdvancedVI, LinearAlgebra, Distributions; μ = zeros(2); -``` -Gaussian variational family: -```@repl locscale L = diagm(ones(2)) |> LowerTriangular; q = VIFullRankGaussian(μ, L) @@ -39,9 +45,12 @@ L = ones(2) |> Diagonal; q = VIMeanFieldGaussian(μ, L) ``` +## Non-Gaussian Variational Families Sudent-T Variational Family: -```@repl locscale +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); ν = 3; # Full-Rank @@ -54,7 +63,10 @@ q = VILocationScale(μ, L, TDist(ν)) ``` Multivariate Laplace family: -```@repl locscale +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); + # Full-Rank L = diagm(ones(2)) |> LowerTriangular; q = VILocationScale(μ, L, Laplace()) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index e9f180f5..0edc47f4 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -11,6 +11,15 @@ abstract type MonteCarloEntropy <: AbstractEntropyEstimator end struct FullMonteCarloEntropy <: MonteCarloEntropy end +""" + StickingTheLandingEntropy() + +The "sticking the landing" entropy estimator. + +# Requirements +- `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. +""" struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) From e66935bb2881a61cf137ff74899e7117c53a9f46 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:34:11 +0100 Subject: [PATCH 070/206] add doc build CI --- .github/workflows/CI.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9731f20c..158da963 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -61,3 +61,30 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info + docs: + name: Documentation + runs-on: ubuntu-latest + permissions: + contents: write + statuses: write + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: '1' + - name: Configure doc environment + run: | + julia --project=docs/ -e ' + using Pkg + Pkg.develop(PackageSpec(path=pwd())) + Pkg.instantiate()' + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - run: | + julia --project=docs -e ' + using Documenter: DocMeta, doctest + using AdvancedVI + DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true) + doctest(AdvancedVI)' From 9f1c647a6fb2b945754e808dcb608e3f19c4cae8 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:47:56 +0100 Subject: [PATCH 071/206] remove convergence criterion for now --- docs/src/families.md | 2 +- src/optimize.jl | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/docs/src/families.md b/docs/src/families.md index e6eaa91b..8ae48be3 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -5,7 +5,7 @@ The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as ```math z \sim q_{\lambda} \qquad\Leftrightarrow\qquad -z \stackrel{d}{=} z = C u + m;\quad u \sim \varphi +z \stackrel{d}{=} C u + m;\quad u \sim \varphi ``` where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. ``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. diff --git a/src/optimize.jl b/src/optimize.jl index 7c876b39..0f2d29e9 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -51,7 +51,6 @@ function optimize( rng ::AbstractRNG = default_rng(), show_progress::Bool = true, callback! = nothing, - #convergence = (args...) -> (false, con_state), adbackend::AbstractADType = AutoForwardDiff(), prog = ProgressMeter.Progress( n_max_iter; @@ -64,7 +63,6 @@ function optimize( λ = copy(λ₀) opt_state = Optimisers.setup(optimizer, λ) est_state = init(objective) - #con_state = init(convergence) grad_buf = DiffResults.GradientResult(λ) stats = NamedTuple[] @@ -89,11 +87,6 @@ function optimize( pm_next!(prog, stat) push!(stats, stat) - - #convergence(rng, t, restructure, λ, q, objective, stat) - #if terminate() - # break - #end end λ, map(identity, stats), opt_state end From c8b3ee3ed7ec43051631462b7674a7c1d66722d7 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:54:12 +0100 Subject: [PATCH 072/206] remove outdated export --- src/AdvancedVI.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a1cf360a..1677be62 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -72,7 +72,6 @@ include("objectives/elbo/advi.jl") export ELBO, ADVI, - ADVIEnergy, ClosedFormEntropy, StickingTheLandingEntropy, FullMonteCarloEntropy From afda1a19527f4197b25a50fcae8e52cdeace660b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 20:53:42 +0100 Subject: [PATCH 073/206] update documentation --- docs/make.jl | 9 +++-- docs/src/index.md | 16 ++++----- docs/src/{families.md => locscale.md} | 4 +-- docs/src/started.md | 51 +++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 13 deletions(-) rename docs/src/{families.md => locscale.md} (96%) create mode 100644 docs/src/started.md diff --git a/docs/make.jl b/docs/make.jl index b9a8eb5f..ca21b5fd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -10,9 +10,12 @@ makedocs(; sitename = "AdvancedVI.jl", modules = [AdvancedVI], format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["Home" => "index.md", - "Families" => "families.md", - "ADVI" => "advi.md"], + pages = ["AdvancedVI" => "index.md", + "Getting Started" => "started.md", + "ELBO Maximization" => [ + "Automatic Differentiation VI" => "advi.md", + "Location Scale Family" => "locscale.md", + ]], ) deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) diff --git a/docs/src/index.md b/docs/src/index.md index be326921..dea6d405 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -4,11 +4,11 @@ CurrentModule = AdvancedVI # AdvancedVI -Documentation for [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl). - -```@index -``` - -```@autodocs -Modules = [AdvancedVI] -``` +## Introduction +[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. +VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. +`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. + +## Provided Algorithms +`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization: +- [Automatic Differentiation Variational Inference](@ref advi) diff --git a/docs/src/families.md b/docs/src/locscale.md similarity index 96% rename from docs/src/families.md rename to docs/src/locscale.md index 8ae48be3..a4bc2dc1 100644 --- a/docs/src/families.md +++ b/docs/src/locscale.md @@ -1,7 +1,7 @@ -# Location-Scale Variational Family +# [Location-Scale Variational Family](@id locscale) -## Description +## Introduction The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as ```math z \sim q_{\lambda} \qquad\Leftrightarrow\qquad diff --git a/docs/src/started.md b/docs/src/started.md new file mode 100644 index 00000000..faff6166 --- /dev/null +++ b/docs/src/started.md @@ -0,0 +1,51 @@ + +# [Getting Started with `AdvancedVI`](@id getting_started) + +## General Usage +Each VI algorithm should provide the following: +1. A variational family +2. A variational objective + +Feeding these two into `optimize` runs the inference procedure. + +```@docs +optimize +``` + +## `ADVI` Example Using `Turing` + +```julia +using Turing +using Bijectors +using Optimisers +using ForwardDiff +using ADTypes + +import AdvancedVI as AVI + +μ_y, σ_y = 1.0, 1.0 +μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0] + +Turing.@model function normallognormal() + y ~ LogNormal(μ_y, σ_y) + z ~ MvNormal(μ_z, Σ_z) +end +model = normallognormal() +b = Bijectors.bijector(model) +b⁻¹ = inverse(b) +prob = DynamicPPL.LogDensityFunction(model) +d = LogDensityProblems.dimension(prob) + +μ = randn(d) +L = Diagonal(ones(d)) +q = AVI.MeanFieldGaussian(μ, L) + +n_max_iter = 10^4 +q, stats = AVI.optimize( + AVI.ADVI(prob, b⁻¹, 10), + q, + n_max_iter; + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +) +``` From 0d37acea1dd96c95d7cef427be7d84fee8d95c09 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 21:12:02 +0100 Subject: [PATCH 074/206] update documentation --- docs/src/started.md | 55 ++++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/docs/src/started.md b/docs/src/started.md index faff6166..26c75a79 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -2,11 +2,13 @@ # [Getting Started with `AdvancedVI`](@id getting_started) ## General Usage -Each VI algorithm should provide the following: -1. A variational family -2. A variational objective +Each VI algorithm provides the followings: +1. Variational families supported by each VI algorithm. +2. A variational objective corresponding to the VI algorithm. +Note that each variational family is subject to its own constraints. +Thus, please refer to the documentation of the variational inference algorithm of interest. -Feeding these two into `optimize` runs the inference procedure. +To use `AdvancedVI`, a user needs to select a `variational family`, `variational objective`, and feed them into `optimize`. ```@docs optimize @@ -14,14 +16,10 @@ optimize ## `ADVI` Example Using `Turing` +In this tutorial, we'll use `Turing` to define a basic `normal-log-normal` model. +ADVI with log bijectors is able to infer this model exactly. ```julia using Turing -using Bijectors -using Optimisers -using ForwardDiff -using ADTypes - -import AdvancedVI as AVI μ_y, σ_y = 1.0, 1.0 μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0] @@ -31,18 +29,43 @@ Turing.@model function normallognormal() z ~ MvNormal(μ_z, Σ_z) end model = normallognormal() +``` + +Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. +Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation. +```julia +using Bijectors + b = Bijectors.bijector(model) b⁻¹ = inverse(b) -prob = DynamicPPL.LogDensityFunction(model) -d = LogDensityProblems.dimension(prob) +``` +Let's now load `AdvancedVI`. +Since ADVI relies on automatic differentiation (AD), hence the "AD" in "ADVI", we need to load an AD library, *before* loading `AdvancedVI`. +Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. +Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. +```julia +using Optimisers +using ForwardDiff +import AdvancedVI as AVI +``` +We now need to select 1. a variational objective, and 2. a variational family. +Here, we will use the [ADVI objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. +```julia +prob = DynamicPPL.LogDensityFunction(model) +objective = AVI.ADVI(prob, b⁻¹, 10), +``` +For the variational family, we will use the classic mean-field Gaussian family. +```julia +d = LogDensityProblems.dimension(prob) μ = randn(d) L = Diagonal(ones(d)) -q = AVI.MeanFieldGaussian(μ, L) - +q = AVI.VIMeanFieldGaussian(μ, L) +``` +It now remains to run inverence! +``` n_max_iter = 10^4 -q, stats = AVI.optimize( - AVI.ADVI(prob, b⁻¹, 10), +q, stats = AVI.optimize( q, n_max_iter; adbackend = AutoForwardDiff(), From b8b113da2b3a64395e9daaf2bbb64e9b0b602a4e Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 21:14:17 +0100 Subject: [PATCH 075/206] update documentation --- docs/src/started.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/started.md b/docs/src/started.md index 26c75a79..355e9350 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -50,10 +50,11 @@ using ForwardDiff import AdvancedVI as AVI ``` We now need to select 1. a variational objective, and 2. a variational family. -Here, we will use the [ADVI objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. +Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. ```julia -prob = DynamicPPL.LogDensityFunction(model) -objective = AVI.ADVI(prob, b⁻¹, 10), +prob = DynamicPPL.LogDensityFunction(model)] +n_montecaro = 10 +objective = AVI.ADVI(prob, b⁻¹, n_montecaro), ``` For the variational family, we will use the classic mean-field Gaussian family. ```julia From b78e713eaf46d3caab540e1d818be9930bea54dc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 16 Aug 2023 23:35:23 +0100 Subject: [PATCH 076/206] fix type error in test --- test/distributions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributions.jl b/test/distributions.jl index 073fff64..9b18d020 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -11,7 +11,7 @@ using Distributions: _logpdf seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) rng = Philox4x(UInt64, seed, 8) realtype = Float64 - ϵ = 1e-2 + ϵ = 1f-2 n_dims = 10 n_montecarlo = 1000_000 From a0564b56bbe86b5885c333aa7fe2ca0e48fa0b24 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 16 Aug 2023 23:35:29 +0100 Subject: [PATCH 077/206] remove default ADType argument --- Project.toml | 2 +- src/optimize.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index ffc41a4b..35650ae5 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" -DiffResults = "1.0.3" +DiffResults = "1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" ForwardDiff = "0.10.25" diff --git a/src/optimize.jl b/src/optimize.jl index 0f2d29e9..93e6f754 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -31,6 +31,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `n_max_iter`: Maximum number of iterations. # Keyword Arguments +- `adbackend`: Automatic differentiation backend. (Type: `<: ADtypes.AbstractADType`.) - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) @@ -47,11 +48,11 @@ function optimize( restructure, λ₀ ::AbstractVector{<:Real}, n_max_iter ::Int; + adbackend::AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), rng ::AbstractRNG = default_rng(), show_progress::Bool = true, callback! = nothing, - adbackend::AbstractADType = AutoForwardDiff(), prog = ProgressMeter.Progress( n_max_iter; desc = "Optimizing", From 3795d1e05f510887df1c2900ab9f7638797ecc87 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:01:52 +0100 Subject: [PATCH 078/206] update README --- README.md | 304 +++++++++++++++--------------------------------------- 1 file changed, 81 insertions(+), 223 deletions(-) diff --git a/README.md b/README.md index 18ba63e5..e8718e7c 100644 --- a/README.md +++ b/README.md @@ -1,250 +1,108 @@ -# AdvancedVI.jl -A library for variational Bayesian inference in Julia. - -At the time of writing (05/02/2020), implementations of the variational inference (VI) interface and some algorithms are implemented in [Turing.jl](https://github.com/TuringLang/Turing.jl). The idea is to soon separate the VI functionality in Turing.jl out and into this package. - -The purpose of this package will then be to provide a common interface together with implementations of standard algorithms and utilities with the goal of ease of use and the ability for other packages, e.g. Turing.jl, to write a light wrapper around AdvancedVI.jl for integration. -As an example, in Turing.jl we support automatic differentiation variational inference (ADVI) but really the only piece of code tied into the Turing.jl is the conversion of a `Turing.Model` to a `logjoint(z)` function which computes `z ↦ log p(x, z)`, with `x` denoting the observations embedded in the `Turing.Model`. As long as this `logjoint(z)` method is compatible with some AD framework, e.g. `ForwardDiff.jl` or `Zygote.jl`, this is all we need from Turing.jl to be able to perform ADVI! - -## [WIP] Interface -- `vi`: the main interface to the functionality in this package - - `vi(model, alg)`: only used when `alg` has a default variational posterior which it will provide. - - `vi(model, alg, q::VariationalPosterior, θ)`: `q` represents the family of variational distributions and `θ` is the initial parameters "indexing" the starting distribution. This assumes that there exists an implementation `Variational.update(q, θ)` which returns the variational posterior corresponding to parameters `θ`. - - `vi(model, alg, getq::Function, θ)`: here `getq(θ)` is a function returning a `VariationalPosterior` corresponding to `θ`. -- `optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())` -- `grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)` - - Different combinations of variational objectives (`vo`), VI methods (`alg`), and variational posteriors (`q`) might use different gradient estimators. `grad!` allows us to specify these different behaviors. +# AdvancedVI.jl +[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. +VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. +`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. +The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. +For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) by simply converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`. ## Examples -### Variational Inference -A very simple generative model is the following - - μ ~ 𝒩(0, 1) - xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n - -where μ and xᵢ are some ℝᵈ vectors and 𝒩 denotes a d-dimensional multivariate Normal distribution. - -Given a set of `n` observations `[x₁, …, xₙ]` we're interested in finding the distribution `p(μ∣x₁, …, xₙ)` over the mean `μ`. We can obtain (an approximation to) this distribution that using AdvancedVI.jl! - -First we generate some observations and set up the problem: -```julia -julia> using Distributions - -julia> d = 2; n = 100; - -julia> observations = randn((d, n)); # 100 observations from 2D 𝒩(0, 1) - -julia> # Define generative model - # μ ~ 𝒩(0, 1) - # xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n - prior(μ) = logpdf(MvNormal(ones(d)), μ) -prior (generic function with 1 method) - -julia> likelihood(x, μ) = sum(logpdf(MvNormal(μ, ones(d)), x)) -likelihood (generic function with 1 method) - -julia> logπ(μ) = likelihood(observations, μ) + prior(μ) -logπ (generic function with 1 method) - -julia> logπ(randn(2)) # <= just checking that it works --311.74132761437653 -``` -Now there are mainly two different ways of specifying the approximate posterior (and its family). The first is by providing a mapping from distribution parameters to the distribution `θ ↦ q(⋅∣θ)`: -```julia -julia> using DistributionsAD, AdvancedVI - -julia> # Using a function z ↦ q(⋅∣z) - getq(θ) = TuringDiagMvNormal(θ[1:d], exp.(θ[d + 1:4])) -getq (generic function with 1 method) -``` -Then we make the choice of algorithm, a subtype of `VariationalInference`, -```julia -julia> # Perform VI - advi = ADVI(10, 10_000) -ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 10000) -``` -And finally we can perform VI! The usual inferface is to call `vi` which behind the scenes takes care of the optimization and returns the resulting variational posterior: -```julia -julia> q = vi(logπ, advi, getq, randn(4)) -[ADVI] Optimizing...100% Time: 0:00:01 -TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[0.16282745378074515, 0.15789310089462574], σ=[0.09519377533754399, 0.09273176907111745]) -``` -Let's have a look at the resulting ELBO: -```julia -julia> AdvancedVI.elbo(advi, q, logπ, 1000) --287.7866366886285 -``` -Unfortunately, the *final* value of the ELBO is not always a very good diagnostic, though the ELBO is an important metric to keep an eye on during training since an *increase* in the ELBO means we're going in the right direction. Luckily, this is such a simple problem that we can indeed obtain a closed form solution! Because we're lazy (at least I am), we'll let [ConjugatePriors.jl](https://github.com/JuliaStats/ConjugatePriors.jl) do this for us: -```julia -julia> # True posterior - using ConjugatePriors -julia> pri = MvNormal(zeros(2), ones(2)); +`AdvancedVI` basically expects a `LogDensityProblem`. +For example, for the normal-log-normal model: +$$ +\begin{aligned} +x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) +\end{aligned} +$$ -julia> true_posterior = posterior((pri, pri.Σ), MvNormal, observations) -DiagNormal( -dim: 2 -μ: [0.1746546592601148, 0.16457110079543008] -Σ: [0.009900990099009901 0.0; 0.0 0.009900990099009901] -) +A `LogDensityProblem` can be implemented as ``` -Comparing to our variational approximation, this looks pretty good! Worth noting that in this particular case the variational posterior seems to overestimate the variance. +using LogDensityProblems -To conclude, let's make a somewhat pretty picture: -```julia -julia> using Plots - -julia> p_samples = rand(true_posterior, 10_000); q_samples = rand(q, 10_000); - -julia> p1 = histogram(p_samples[1, :], label="p"); histogram!(q_samples[1, :], alpha=0.7, label="q") - -julia> title!(raw"$\mu_1$") +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end -julia> p2 = histogram(p_samples[2, :], label="p"); histogram!(q_samples[2, :], alpha=0.7, label="q") +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end -julia> title!(raw"$\mu_2$") +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end -julia> plot(p1, p2) +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end ``` -![Histogram](hist.png?raw=true) - -### Simple example: using Advanced.jl to directly minimize the KL-divergence between two distributions `p(z)` and `q(z)` -In VI we aim to approximate the true posterior `p(z ∣ x)` by some approximate variational posterior `q(z)` by maximizing the ELBO: - - ELBO(q) = 𝔼_q[log p(x, z) - log q(z)] - -Observe that we can express the ELBO as the negative KL-divergence between `p(x, ⋅)` and `q(⋅)`: - - ELBO(q) = - 𝔼_q[log (q(z) / p(x, z))] - = - KL(q(⋅) || p(x, ⋅)) - -So if we apply VI to something that isn't an actual posterior, i.e. there's no data involved and we write `p(z ∣ x) = p(z)`, we're really just minimizing the KL-divergence between the distributions. - -Therefore, we can try out `AdvancedVI.jl` real quick by applying using the interface to minimize the KL-divergence between two distributions: +Since the support of `x` is constrained to be $$\mathbb{R}_+$$, and inference is best done in the unconstrained space $$\mathbb{R}_+$$, we need to use a *bijector* to match support. +This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*, 2015). ```julia -julia> using Distributions, DistributionsAD, AdvancedVI - -julia> # Target distribution - p = MvNormal(ones(2)) -ZeroMeanDiagNormal( -dim: 2 -μ: [0.0, 0.0] -Σ: [1.0 0.0; 0.0 1.0] -) +using Bijectors -julia> logπ(z) = logpdf(p, z) -logπ (generic function with 1 method) - -julia> # Make a choice of VI algorithm - advi = ADVI(10, 1000) -ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 1000) -``` -Now there are two different ways of specifying the approximate posterior (and its family); the first is by providing a mapping from parameters to distribution `θ ↦ q(⋅∣θ)`: -```julia -julia> # Using a function z ↦ q(⋅∣z) - getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -getq (generic function with 1 method) - -julia> # Perform VI - q = vi(logπ, advi, getq, randn(4)) -┌ Info: [ADVI] Should only be seen once: optimizer created for θ -└ objectid(θ) = 0x5ddb564423896704 -[ADVI] Optimizing...100% Time: 0:00:01 -TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[-0.012691337868985757, -0.0004442434543332919], σ=[1.0334797673569802, 0.9957355128767893]) -``` -Or we can check the ELBO (which in this case since, as mentioned, doesn't involve data, is the negative KL-divergence): -```julia -julia> AdvancedVI.elbo(advi, q, logπ, 1000) # empirical estimate -0.08031049170093245 +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end ``` -It's worth noting that the actual value of the ELBO doesn't really tell us too much about the quality of fit. In this particular case, because we're *directly* minimizing the KL-divergence, we can only say something useful if we reach 0, in which case we have obtained the true distribution. -Let's just quickly check the mean-squared error between the `log p(z)` and `log q(z)` for a random set of samples from the target `p`: -```julia -julia> zs = rand(p, 100); +A simpler approach is to use `Turing`, where a `Turing.Model` can be automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated. -julia> mean(abs2, logpdf(q, zs) - logpdf(p, zs)) -0.0014889109427524852 +Let us instantiate a random normal-log-normal model. +```julia +using PDMats + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) ``` -That doesn't look too bad! - -## Implementing your own training loop -Sometimes it might be convenient to roll your own training loop rather than using `vi(...)`. Here's some psuedo-code for how one would do that when used together with Turing.jl: +ADVI can be used as follows: ```julia -using Turing, AdvancedVI, DiffResults -using Turing: Variational - -using ProgressMeter - -# Assuming you have an instance of a Turing model (`model`) - -# 1. Create log-joint needed for ELBO evaluation -logπ = Variational.make_logjoint(model) - -# 2. Define objective -variational_objective = Variational.ELBO() - -# 3. Optimizer -optimizer = Variational.DecayedADAGrad() - -# 4. VI-algorithm -alg = ADVI(10, 1000) - -# 5. Variational distribution -function getq(θ) - # ... -end - -# 6. [OPTIONAL] Implement convergence criterion -function hasconverged(args...) - # ... -end - -# 7. [OPTIONAL] Implement a callback for tracking stats -function callback(args...) - # ... -end - -# 8. Train -converged = false -step = 1 - -prog = ProgressMeter.Progress(num_steps, 1) - -diff_results = DiffResults.GradientResult(θ_init) - -while (step ≤ num_steps) && !converged - # 1. Compute gradient and objective value; results are stored in `diff_results` - AdvancedVI.grad!(variational_objective, alg, getq, model, diff_results) - - # 2. Extract gradient from `diff_result` - ∇ = DiffResults.gradient(diff_result) - - # 3. Apply optimizer, e.g. multiplying by step-size - Δ = apply!(optimizer, θ, ∇) - - # 4. Update parameters - @. θ = θ - Δ - - # 5. Do whatever analysis you want - callback(args...) - - # 6. Update - converged = hasconverged(...) # or something user-defined - step += 1 +using LinearAlgebra +using Optimisers +using ADTypes, ForwardDiff +import AdvancedVI as AVI + +b = Bijectors.bijector(model) +b⁻¹ = inverse(b) + +# ADVI objective +objective = AVI.ADVI(model, 10; b=b⁻¹) + +# Mean-field Gaussian variational family +d = LogDensityProblems.dimension(model) +μ = randn(d) +L = Diagonal(ones(d)) +q = AVI.VIMeanFieldGaussian(μ, L) + +# Run inference +n_max_iter = 10^4 +q, stats, _ = AVI.optimize( + objective, + q, + n_max_iter; + adbackend = ADTypes.AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +) - ProgressMeter.next!(prog) -end +# Evaluate final ELBO with 10^3 Monte Carlo samples +objective(q; n_samples=10^3) ``` ## References -- Jordan, Michael I., Zoubin Ghahramani, Tommi S. Jaakkola, and Lawrence K. Saul. "An introduction to variational methods for graphical models." Machine learning 37, no. 2 (1999): 183-233. -- Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. "Variational inference: A review for statisticians." Journal of the American statistical Association 112, no. 518 (2017): 859-877. - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015. -- Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882. -- Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003. From 28a35bcd0ce6bd4489915ae1cf37801db211b2ec Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:02:04 +0100 Subject: [PATCH 079/206] update make getting started example actually run Julia --- docs/Project.toml | 13 ++++- docs/src/started.md | 115 +++++++++++++++++++++++++++++++++----------- 2 files changed, 98 insertions(+), 30 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index c625d07f..182edd3e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,7 +1,18 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [compat] -Documenter = "0.26, 0.27" \ No newline at end of file +ADTypes = "0.1.6" +Bijectors = "0.13.6" +Documenter = "0.26, 0.27" +LogDensityProblems = "2.1.1" diff --git a/docs/src/started.md b/docs/src/started.md index 355e9350..fec60f1a 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -14,62 +14,119 @@ To use `AdvancedVI`, a user needs to select a `variational family`, `variational optimize ``` -## `ADVI` Example Using `Turing` +## `ADVI` Example +In this tutorial, we will work with a basic `normal-log-normal` model. +```math +\begin{aligned} +x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) +\end{aligned} +``` +ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. -In this tutorial, we'll use `Turing` to define a basic `normal-log-normal` model. -ADVI with log bijectors is able to infer this model exactly. -```julia -using Turing +Using the `LogDensityProblems` interface, we the model can be defined as follows: +```@example advi +using LogDensityProblems +using SimpleUnPack -μ_y, σ_y = 1.0, 1.0 -μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0] +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end -Turing.@model function normallognormal() - y ~ LogNormal(μ_y, σ_y) - z ~ MvNormal(μ_z, Σ_z) +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end -model = normallognormal() + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end +``` +Let's now instantiate the model +```@example advi +using PDMats + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation. -```julia +```@example advi using Bijectors -b = Bijectors.bijector(model) -b⁻¹ = inverse(b) +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +b = Bijectors.bijector(model); +b⁻¹ = inverse(b) ``` Let's now load `AdvancedVI`. Since ADVI relies on automatic differentiation (AD), hence the "AD" in "ADVI", we need to load an AD library, *before* loading `AdvancedVI`. Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. -```julia +```@example advi using Optimisers -using ForwardDiff +using ADTypes, ForwardDiff import AdvancedVI as AVI ``` We now need to select 1. a variational objective, and 2. a variational family. Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. -```julia -prob = DynamicPPL.LogDensityFunction(model)] -n_montecaro = 10 -objective = AVI.ADVI(prob, b⁻¹, n_montecaro), +```@example advi +n_montecaro = 10; +objective = AVI.ADVI(model, n_montecaro; b = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. -```julia -d = LogDensityProblems.dimension(prob) -μ = randn(d) -L = Diagonal(ones(d)) +```@example advi +using LinearAlgebra + +d = LogDensityProblems.dimension(model); +μ = randn(d); +L = Diagonal(ones(d)); q = AVI.VIMeanFieldGaussian(μ, L) ``` -It now remains to run inverence! -``` -n_max_iter = 10^4 -q, stats = AVI.optimize( +Passing `objective` and the initial variational approximation `q` to `optimize` performs inference. +```@example advi +n_max_iter = 10^4 +q, stats, _ = AVI.optimize( + objective, q, n_max_iter; adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-3) -) +); +``` + +The selected inference procedure stores per-iteration statistics into `stats`. +For instance, the ELBO can be ploted as follows: +```@example advi +using Plots + +t = [stat.iteration for stat ∈ stats] +y = [stat.elbo for stat ∈ stats] +plot(t[1:100:end], y[1:100:end]) +savefig("advi_example_elbo.svg"); nothing +``` +![](advi_example_elbo.svg) +Further information can be gathered by defining your own `callback!`. + +The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: +```@example advi +ELBO = objective(q; n_samples=10^4) ``` From 620b38e7d345c60d59c08174144f1349618ff60c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:02:16 +0100 Subject: [PATCH 080/206] fix remove Float32 tests for inference tests --- ext/AdvancedVIForwardDiffExt.jl | 2 +- test/advi_locscale.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index e6b03af2..5949bdf8 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -11,8 +11,8 @@ else using ..AdvancedVI: ADTypes, DiffResults end -# extract chunk size from AutoForwardDiff getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize + function AdvancedVI.value_and_gradient!( ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult ) where {T<:Real} diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index d4ef7aec..e4c81402 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -25,7 +25,7 @@ include("models/utils.jl") @testset "advi" begin @testset "locscale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype ∈ [Float32, Float64], + realtype ∈ [Float64], # Currently only tested against Float64 (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, From fa533981d6c3208e008d04f35a18ec08728ca608 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:54:13 +0100 Subject: [PATCH 081/206] update version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 35650ae5..2092b0cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.2.4" +version = "0.3.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From e909f4106e919e2d834a4f73eac3ca929bd5b9dd Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 20:04:34 +0100 Subject: [PATCH 082/206] add documentation publishing url --- docs/make.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index ca21b5fd..5d371608 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,15 +7,16 @@ DocMeta.setdocmeta!( ) makedocs(; - sitename = "AdvancedVI.jl", modules = [AdvancedVI], + sitename = "AdvancedVI.jl", + repo = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}", format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["AdvancedVI" => "index.md", - "Getting Started" => "started.md", - "ELBO Maximization" => [ - "Automatic Differentiation VI" => "advi.md", - "Location Scale Family" => "locscale.md", - ]], + pages = ["AdvancedVI" => "index.md", + "Getting Started" => "started.md", + "ELBO Maximization" => [ + "Automatic Differentiation VI" => "advi.md", + "Location Scale Family" => "locscale.md", + ]], ) deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) From 43f5b751abb963533cbb6835ca6c8315a53a41d2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 20:17:04 +0100 Subject: [PATCH 083/206] fix wrong uuid for ForwardDiff --- src/AdvancedVI.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 1677be62..c45d4997 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -106,7 +106,7 @@ function __init__() @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("../ext/AdvancedVIEnzymeExt.jl") end - @require ForwardDiff = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/AdvancedVIForwardDiffExt.jl") end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin From 468d5ca3aa94f7c83287633beba23aa5d174ca88 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 17 Aug 2023 21:44:15 +0100 Subject: [PATCH 084/206] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 158da963..26f6876f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: - windows-latest arch: - x64 - - x86 + # - x86 # Uncomment after https://github.com/JuliaTesting/ReTest.jl/pull/52 is merged exclude: - os: macOS-latest arch: x86 From c07a5118a237fd5eb3a478a88fdcefe06673b366 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 21:49:26 +0100 Subject: [PATCH 085/206] refactor use `sum` and `mean` instead of abusing `mapreduce` --- src/distributions/location_scale.jl | 4 ++-- src/objectives/elbo/advi.jl | 5 ++--- src/objectives/elbo/entropy.jl | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index e901e8de..3113c679 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -42,12 +42,12 @@ end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) + sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) + sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 8bc14bc9..67af4375 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -55,10 +55,9 @@ function (advi::ADVI)( q_η::ContinuousMultivariateDistribution, ηs ::AbstractMatrix ) - n_samples = size(ηs, 2) - 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ + 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) - (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples + (advi.ℓπ(zᵢ) + logdetjacᵢ) end ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 0edc47f4..694eacef 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -23,9 +23,8 @@ The "sticking the landing" entropy estimator. struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) - n_samples = size(ηs, 2) - mapreduce(+, eachcol(ηs)) do ηᵢ - -logpdf(q, ηᵢ) / n_samples + mean(eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) end end From 13a8a445af64690b61137f6791f4f11eb6130a2b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 22:14:42 +0100 Subject: [PATCH 086/206] remove tests for `FullMonteCarlo` --- test/advi_locscale.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index e4c81402..962d3169 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -35,7 +35,6 @@ include("models/utils.jl") (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), - :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = FullMonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), From aadf8d397aad300b6e5d502b8a90bd0f2724d778 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 01:31:58 +0100 Subject: [PATCH 087/206] add tests for the `optimize` interface --- test/advi_locscale.jl | 4 +-- test/optimize.jl | 84 +++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 3 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 test/optimize.jl diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 962d3169..bf51199f 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -38,8 +38,8 @@ include("models/utils.jl") ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), - # :ReverseDiff => AutoReverseDiff(), - # :Zygote => AutoZygote(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), # :Enzyme => AutoEnzyme(), ) diff --git a/test/optimize.jl b/test/optimize.jl new file mode 100644 index 00000000..3ece467f --- /dev/null +++ b/test/optimize.jl @@ -0,0 +1,84 @@ + +using ReTest +using Bijectors +using LogDensityProblems +using Optimisers +using Distributions +using PDMats +using LinearAlgebra +using SimpleUnPack: @unpack + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("models/normallognormal.jl") +include("models/utils.jl") + +@testset "optimize" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + T = 1000 + modelstats = normallognormal_meanfield(Float64; rng) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + # Global Test Configurations + b⁻¹ = Bijectors.bijector(model) |> inverse + μ₀ = zeros(Float64, n_dims) + L₀ = ones(Float64, n_dims) |> Diagonal + q₀ = VIMeanFieldGaussian(μ₀, L₀) + obj = ADVI(model, 10; b=b⁻¹) + + adbackend = AutoForwardDiff() + optimizer = Optimisers.Adam(1e-2) + + rng = Philox4x(UInt64, seed, 8) + q_ref, stats_ref, _ = optimize( + obj, q₀, T; + optimizer, + show_progress = false, + rng, + adbackend, + ) + λ_ref, _ = Optimisers.destructure(q_ref) + + @testset "restructure" begin + λ₀, re = Optimisers.destructure(q₀) + + rng = Philox4x(UInt64, seed, 8) + λ, stats, _ = optimize( + obj, re, λ₀, T; + optimizer, + show_progress = false, + rng, + adbackend, + ) + @test λ == λ_ref + @test stats == stats_ref + end + + @testset "callback" begin + rng = Philox4x(UInt64, seed, 8) + test_values = rand(rng, T) + + callback!(; stat, est_state, restructure, λ) = begin + (test_value = test_values[stat.iteration],) + end + + rng = Philox4x(UInt64, seed, 8) + _, stats, _ = optimize( + obj, q₀, T; + show_progress = false, + rng, + adbackend, + callback! + ) + @test [stat.test_value for stat ∈ stats] == test_values + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 68225fd9..6bd3bc49 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,11 +8,13 @@ using Random123 using Statistics using Distributions using LinearAlgebra + using AdvancedVI include("ad.jl") include("distributions.jl") include("advi_locscale.jl") +include("optimize.jl") @main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) From 8c4e13db72524ad31bf6306219436d3b78320237 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 01:33:05 +0100 Subject: [PATCH 088/206] fix turn off Zygote tests for now --- test/advi_locscale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index bf51199f..e8b4be03 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -39,7 +39,7 @@ include("models/utils.jl") (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), + # :Zygote => AutoZygote(), # :Enzyme => AutoEnzyme(), ) From 0b708e6297d781722a582058d42f7e0917cf49bd Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 03:09:11 +0100 Subject: [PATCH 089/206] remove unused function --- src/objectives/elbo/entropy.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 694eacef..022ed4f6 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -5,8 +5,6 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end -skip_entropy_gradient(::ClosedFormEntropy) = false - abstract type MonteCarloEntropy <: AbstractEntropyEstimator end struct FullMonteCarloEntropy <: MonteCarloEntropy end From be61acd46d457206cbd07386377958d5afb178e3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 03:51:34 +0100 Subject: [PATCH 090/206] refactor change bijector field name, simplify STL estimator --- Project.toml | 2 ++ src/AdvancedVI.jl | 4 +-- src/objectives/elbo/advi.jl | 46 +++++++--------------------------- src/objectives/elbo/entropy.jl | 15 ++++++----- test/advi_locscale.jl | 4 +-- test/optimize.jl | 2 +- 6 files changed, 25 insertions(+), 48 deletions(-) diff --git a/Project.toml b/Project.toml index 2092b0cb..e099308a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.3.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -37,6 +38,7 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" +ChainRules = "1.53.0" DiffResults = "1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index c45d4997..cca220f1 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -23,7 +23,7 @@ using LogDensityProblems using ADTypes, DiffResults using ADTypes: AbstractADType - +using ChainRules: @ignore_derivatives using FillArrays using PDMats @@ -74,7 +74,7 @@ export ADVI, ClosedFormEntropy, StickingTheLandingEntropy, - FullMonteCarloEntropy + MonteCarloEntropy # Variational Families diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 67af4375..788449d1 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -11,7 +11,7 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) # Keyword Arguments - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) - `cv`: A control variate. -- `b`: A bijector mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) +- `invbij`: A bijective mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) # Requirements - ``q_{\\lambda}`` implements `rand`. @@ -23,7 +23,7 @@ struct ADVI{Tlogπ, B, EntropyEst <: AbstractEntropyEstimator, ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective ℓπ::Tlogπ - b::B + invbij::B entropy::EntropyEst cv::ControlVar n_samples::Int @@ -31,7 +31,7 @@ struct ADVI{Tlogπ, B, function ADVI(prob, n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy(), cv::Union{<:AbstractControlVariate, Nothing} = nothing, - b = Bijectors.identity) + invbij = Bijectors.identity) cap = LogDensityProblems.capabilities(prob) if cap === nothing throw( @@ -41,7 +41,7 @@ struct ADVI{Tlogπ, B, ) end ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b), typeof(entropy), typeof(cv)}(ℓπ, b, entropy, cv, n_samples) + new{typeof(ℓπ), typeof(invbij), typeof(entropy), typeof(cv)}(ℓπ, invbij, entropy, cv, n_samples) end end @@ -56,7 +56,7 @@ function (advi::ADVI)( ηs ::AbstractMatrix ) 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) (advi.ℓπ(zᵢ) + logdetjacᵢ) end ℍ = advi.entropy(q_η, ηs) @@ -86,50 +86,22 @@ function (advi::ADVI)( advi(rng, q_η, ηs) end -function estimate_advi_gradient_maybe_stl!( - rng::AbstractRNG, - adbackend::AbstractADType, - advi::ADVI{P, B, StickingTheLandingEntropy, CV}, - λ::Vector{<:Real}, - restructure, - out::DiffResults.MutableDiffResult -) where {P, B, CV} - q_η_stop = restructure(λ) - f(λ′) = begin - q_η = restructure(λ′) - ηs = rand(rng, q_η, advi.n_samples) - -advi(rng, q_η_stop, ηs) - end - value_and_gradient!(adbackend, f, λ, out) -end - -function estimate_advi_gradient_maybe_stl!( +function estimate_gradient( rng::AbstractRNG, adbackend::AbstractADType, - advi::ADVI{P, B, <:Union{ClosedFormEntropy, FullMonteCarloEntropy}, CV}, + advi::ADVI, + est_state, λ::Vector{<:Real}, restructure, out::DiffResults.MutableDiffResult -) where {P, B, CV} +) f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η, ηs) end value_and_gradient!(adbackend, f, λ, out) -end -function estimate_gradient( - rng::AbstractRNG, - adbackend::AbstractADType, - advi::ADVI, - est_state, - λ::Vector{<:Real}, - restructure, - out::DiffResults.MutableDiffResult -) - estimate_advi_gradient_maybe_stl!( - rng, adbackend, advi, λ, restructure, out) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 022ed4f6..97ccda29 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -5,9 +5,13 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end -abstract type MonteCarloEntropy <: AbstractEntropyEstimator end +struct MonteCarloEntropy <: AbstractEntropyEstimator end -struct FullMonteCarloEntropy <: MonteCarloEntropy end +function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) + mean(eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) + end +end """ StickingTheLandingEntropy() @@ -18,11 +22,10 @@ The "sticking the landing" entropy estimator. - `q` implements `logpdf`. - `logpdf(q, η)` must be differentiable by the selected AD framework. """ -struct StickingTheLandingEntropy <: MonteCarloEntropy end +struct StickingTheLandingEntropy <: AbstractEntropyEstimator end -function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) - mean(eachcol(ηs)) do ηᵢ +function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix) + @ignore_derivatives mean(eachcol(ηs)) do ηᵢ -logpdf(q, ηᵢ) end end - diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index e8b4be03..71cf22d5 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -33,8 +33,8 @@ include("models/utils.jl") :NormalFullRank => normal_fullrank, ), (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), - :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹, entropy = StickingTheLandingEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/optimize.jl b/test/optimize.jl index 3ece467f..d514d236 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -33,7 +33,7 @@ include("models/utils.jl") μ₀ = zeros(Float64, n_dims) L₀ = ones(Float64, n_dims) |> Diagonal q₀ = VIMeanFieldGaussian(μ₀, L₀) - obj = ADVI(model, 10; b=b⁻¹) + obj = ADVI(model, 10; invbij=b⁻¹) adbackend = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) From fb519a501585fd279a62bce331ea81b19627ba06 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 03:51:59 +0100 Subject: [PATCH 091/206] update documentation --- docs/src/advi.md | 177 +++++++++++++++++++++++++++++++++++++++++--- docs/src/started.md | 8 +- 2 files changed, 170 insertions(+), 15 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 37b3541b..3719c89e 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -66,34 +66,187 @@ ADVI The STL control variate was proposed by Roeder *et al.* (2017). By slightly modifying the differentiation path, it implicitly forms a control variate of the form of ```math -\mathrm{CV}_{\mathrm{STL}}\left(z\right) \triangleq \mathbb{H}\left(q_{\lambda}\right) + \log q_{\lambda}\left(z\right), +\begin{aligned} + \mathrm{CV}_{\mathrm{STL}}\left(z\right) + &\triangleq + \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) \\ + &= + -\nabla_{\lambda} \mathbb{E}_{z \sim q_{\nu}} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) +\end{aligned} ``` -which has a mean of zero. +where ``\nu = \lambda`` is set to avoid differentiating through the density of ``q_{\lambda}``. +We can see that this vector-valued function has a mean of zero and is therefore a valid control variate. Adding this to the closed-form entropy ELBO estimator yields the STL estimator: ```math \begin{aligned} - \widehat{\mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) - &\triangleq \mathbb{E}\left[ \log \pi \left(z\right) \right] - \log q_{\lambda} \left(z\right) \\ - &= \mathbb{E}\left[ \log \pi\left(z\right) \right] - + \mathbb{H}\left(q_{\lambda}\right) - \mathrm{CV}_{\mathrm{STL}}\left(z\right) \\ - &= \widehat{\mathrm{ELBO}}\left(\lambda\right) - - \mathrm{CV}_{\mathrm{STL}}\left(z\right), + \widehat{\nabla \mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) + &\triangleq \mathbb{E}_{u \sim \varphi}\left[ + \nabla_{\lambda} \log \pi \left(z_{\lambda}\left(u\right)\right) + - + \nabla_{\lambda} \log q_{\nu} \left(z_{\lambda}\left(u\right)\right) + \right] + \\ + &= + \mathbb{E}\left[ \nabla_{\lambda} \log \pi\left(z_{\lambda}\left(u\right)\right) \right] + + + \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + - + \mathrm{CV}_{\mathrm{STL}}\left(z\right) + \\ + &= + \widehat{\nabla \mathrm{ELBO}}\left(\lambda\right) + - + \mathrm{CV}_{\mathrm{STL}}\left(z\right), \end{aligned} ``` -which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. +which has the same expectation as the original ADVI estimator, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. The conditions for which the STL estimator results in lower variance is still an active subject for research. +The main downside of the STL estimator is that it needs to evaluate and differentiate the log density of ``q_{\lambda}`` in every iteration. +Depending on the variational family, this might be computationally inefficient or even numerically unstable. +For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a back-substitution must be performed at every step, making the per-iteration complexity ``\mathcal{O}(d^3)`` and reducing numerical stability. + + The STL control variate can be used by changing the entropy estimator using the following object: ```@docs StickingTheLandingEntropy ``` -For example: -```julia -ADVI(prob, n_samples; entropy = StickingTheLandingEntropy(), b = bijector) +```@setup stl +using LogDensityProblems +using SimpleUnPack +using PDMats +using Bijectors +using LinearAlgebra +using Plots + +using Optimisers +using ADTypes, ForwardDiff +import AdvancedVI as AVI + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); + +d = LogDensityProblems.dimension(model); +μ = randn(d); +L = Diagonal(ones(d)); +q0 = AVI.VIMeanFieldGaussian(μ, L) + +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end ``` +Let us come back to the example in [Getting Started](@ref getting_started), where a `LogDensityProblem` is given as `model`. +In this example, the true posterior is contained within the variational family. +This setting is known as "perfect variational family specification." +In this case, the STL estimator is able to converge exponentially fast to the true solution. + +Recall that the original ADVI objective with a closed-form entropy (CFE) is given as follows: +```@example stl +n_montecarlo = 1; +b = Bijectors.bijector(model); +b⁻¹ = inverse(b) + +cfe = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹) +``` +The STL estimator can instead be created as follows: +```@example stl +stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), invbij = b⁻¹); +``` + +```@setup stl +n_max_iter = 10^4 + +idx = [1] +callback!(; stat, est_state, restructure, λ) = begin + if mod(idx[1], 100) == 1 + idx[:] .+= 1 + (elbo_accurate = cfe(restructure(λ); n_samples=10^4),) + else + idx[:] .+= 1 + NamedTuple() + end +end + +_, stats_cfe, _ = AVI.optimize( + cfe, + q0, + n_max_iter; + show_progress = false, + callback! = callback!, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +idx[:] .= 1 +_, stats_stl, _ = AVI.optimize( + stl, + q0, + n_max_iter; + show_progress = false, + callback! = callback!, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +fmc = AVI.ADVI(model, n_montecarlo; entropy = AVI.MonteCarloEntropy(), invbij = b⁻¹) +idx[:] .= 1 +_, stats_fmc, _ = AVI.optimize( + fmc, + q0, + n_max_iter; + show_progress = false, + callback! = callback!, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +t = [stat.iteration for stat ∈ stats_cfe[1:100:end]] +y_cfe = [stat.elbo_accurate for stat ∈ stats_cfe[1:100:end]] +y_stl = [stat.elbo_accurate for stat ∈ stats_stl[1:100:end]] +y_fmc = [stat.elbo_accurate for stat ∈ stats_fmc[1:100:end]] +plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +plot!(t, y_fmc, label="ADVI FMC", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +savefig("advi_stl_elbo.svg") +nothing +``` +![](advi_stl_elbo.svg) + +We can see that the noise of the STL estimator converges to a more accurate solution compared to the CFE estimator. + + ## References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. diff --git a/docs/src/started.md b/docs/src/started.md index fec60f1a..b89a140a 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -90,7 +90,7 @@ We now need to select 1. a variational objective, and 2. a variational family. Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. ```@example advi n_montecaro = 10; -objective = AVI.ADVI(model, n_montecaro; b = b⁻¹) +objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. ```@example advi @@ -120,10 +120,12 @@ using Plots t = [stat.iteration for stat ∈ stats] y = [stat.elbo for stat ∈ stats] -plot(t[1:100:end], y[1:100:end]) -savefig("advi_example_elbo.svg"); nothing +plot(t, y, label="ADVI", xlabel="Iteration", ylabel="ELBO") +savefig("advi_example_elbo.svg") +nothing ``` ![](advi_example_elbo.svg) + Further information can be gathered by defining your own `callback!`. The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: From 8682fd92d7746e3f6741bbcb2f2029b12653ba72 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 04:00:17 +0100 Subject: [PATCH 092/206] update STL documentation --- docs/src/advi.md | 42 ++++++++---------------------------------- 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 3719c89e..0d5b9568 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -63,6 +63,7 @@ ADVI ``` ## The `StickingTheLanding` Control Variate + The STL control variate was proposed by Roeder *et al.* (2017). By slightly modifying the differentiation path, it implicitly forms a control variate of the form of ```math @@ -188,63 +189,36 @@ stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), i ```@setup stl n_max_iter = 10^4 -idx = [1] -callback!(; stat, est_state, restructure, λ) = begin - if mod(idx[1], 100) == 1 - idx[:] .+= 1 - (elbo_accurate = cfe(restructure(λ); n_samples=10^4),) - else - idx[:] .+= 1 - NamedTuple() - end -end - _, stats_cfe, _ = AVI.optimize( cfe, q0, n_max_iter; show_progress = false, - callback! = callback!, adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-3) ); -idx[:] .= 1 _, stats_stl, _ = AVI.optimize( stl, q0, n_max_iter; show_progress = false, - callback! = callback!, - adbackend = AutoForwardDiff(), - optimizer = Optimisers.Adam(1e-3) -); - -fmc = AVI.ADVI(model, n_montecarlo; entropy = AVI.MonteCarloEntropy(), invbij = b⁻¹) -idx[:] .= 1 -_, stats_fmc, _ = AVI.optimize( - fmc, - q0, - n_max_iter; - show_progress = false, - callback! = callback!, adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-3) ); -t = [stat.iteration for stat ∈ stats_cfe[1:100:end]] -y_cfe = [stat.elbo_accurate for stat ∈ stats_cfe[1:100:end]] -y_stl = [stat.elbo_accurate for stat ∈ stats_stl[1:100:end]] -y_fmc = [stat.elbo_accurate for stat ∈ stats_fmc[1:100:end]] -plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) -plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) -plot!(t, y_fmc, label="ADVI FMC", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +t = [stat.iteration for stat ∈ stats_cfe] +y_cfe = [stat.elbo for stat ∈ stats_cfe] +y_stl = [stat.elbo for stat ∈ stats_stl] +plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO") +plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO") savefig("advi_stl_elbo.svg") nothing ``` ![](advi_stl_elbo.svg) -We can see that the noise of the STL estimator converges to a more accurate solution compared to the CFE estimator. +We can see that the noise of the STL estimator becomes smaller as VI converges. +However, the difference in speed of convergence may not always be significant. ## References From 9a16ee109a8b095e36d700389b18a35dc1355c2c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 04:01:48 +0100 Subject: [PATCH 093/206] update STL documentation --- docs/src/advi.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 0d5b9568..afb780cb 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -218,7 +218,7 @@ nothing ![](advi_stl_elbo.svg) We can see that the noise of the STL estimator becomes smaller as VI converges. -However, the difference in speed of convergence may not always be significant. +However, the speed of convergence may not always be significantly different. ## References From fc74afaef98e8c31ed04c55abdce20d25a644e4d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 04:03:33 +0100 Subject: [PATCH 094/206] update location scale documentation --- docs/src/locscale.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/locscale.md b/docs/src/locscale.md index a4bc2dc1..63ff5cb4 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -10,6 +10,7 @@ z \stackrel{d}{=} C u + m;\quad u \sim \varphi where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. ``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. + The probability density is given by ```math q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)) @@ -19,6 +20,8 @@ and the entropy is given as \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|, ``` where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. +Notice the ``\mathcal{H}(\varphi)`` does not depend on ``\log |C|``. +The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. ## Constructors From 4be30a1a44c70b4e9356768fd2d8ac662e7bc461 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 20 Aug 2023 00:10:48 +0100 Subject: [PATCH 095/206] fix README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e8718e7c..c43748e5 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,14 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije `AdvancedVI` basically expects a `LogDensityProblem`. For example, for the normal-log-normal model: $$ -\begin{aligned} +\begin{align*} x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) -\end{aligned} -$$ +\end{align*} +$$ A `LogDensityProblem` can be implemented as -``` +```julia using LogDensityProblems struct NormalLogNormal{MX,SX,MY,SY} From c58309dbaea25c986074b69a02f5bc6035dfcde8 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 20 Aug 2023 00:12:15 +0100 Subject: [PATCH 096/206] fix math in README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c43748e5..8def2d98 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije `AdvancedVI` basically expects a `LogDensityProblem`. For example, for the normal-log-normal model: + $$ \begin{align*} x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ From 5b5bd3e9c3f4e90ac0d34f789b17c43c199ebd7d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 20 Aug 2023 03:08:20 +0100 Subject: [PATCH 097/206] add gradient to arguments of callback!, remove `gradient_norm` info --- src/objectives/elbo/advi.jl | 2 +- src/optimize.jl | 8 ++++---- test/optimize.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 788449d1..d8719fa7 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -57,7 +57,7 @@ function (advi::ADVI)( ) 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) - (advi.ℓπ(zᵢ) + logdetjacᵢ) + advi.ℓπ(zᵢ) + logdetjacᵢ end ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ diff --git a/src/optimize.jl b/src/optimize.jl index 93e6f754..43b06689 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -26,7 +26,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie # Arguments - `objective`: Variational Objective. - `λ₀`: Initial value of the variational parameters. -- `restructure`: Function that reconstructs the variational approximation from the flattened parameters. +- `restruct`: Function that reconstructs the variational approximation from the flattened parameters. - `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. - `n_max_iter`: Maximum number of iterations. @@ -35,7 +35,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stats, restructure, λ)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) +- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) # Returns @@ -76,11 +76,11 @@ function optimize( g = DiffResults.gradient(grad_buf) opt_state, λ = Optimisers.update!(opt_state, λ, g) - stat′ = (iteration=t, gradient_norm=norm(g)) + stat′ = (iteration = t,) stat = merge(stat, stat′) if !isnothing(callback!) - stat′ = callback!(; est_state, stat, restructure, λ) + stat′ = callback!(; est_state, stat, λ, g, restructure) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end diff --git a/test/optimize.jl b/test/optimize.jl index d514d236..920a3070 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -67,7 +67,7 @@ include("models/utils.jl") rng = Philox4x(UInt64, seed, 8) test_values = rand(rng, T) - callback!(; stat, est_state, restructure, λ) = begin + callback!(; stat, est_state, restructure, λ, g) = begin (test_value = test_values[stat.iteration],) end From 967021d2a1aa827d9dedda00c2b3eae39638986e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:43:43 +0100 Subject: [PATCH 098/206] fix math in README.md Co-authored-by: David Widmann --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8def2d98..83c2e8bc 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,10 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije For example, for the normal-log-normal model: $$ -\begin{align*} -x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) -\end{align*} +\begin{aligned} +x &\sim \operatorname{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) +\end{aligned} $$ A `LogDensityProblem` can be implemented as From 4dab522ff2583f7a622f7c6d35f829f8daf37cf2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:44:16 +0100 Subject: [PATCH 099/206] fix type constraint in `ZygoteExt` Co-authored-by: David Widmann --- ext/AdvancedVIZygoteExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index b447d071..c3d891bb 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -12,10 +12,10 @@ else end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} + ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) y, back = Zygote.pullback(f, θ) - ∇θ = back(one(T)) + ∇θ = back(one(y)) DiffResults.value!(out, y) DiffResults.gradient!(out, first(∇θ)) return out From 8ab2f19d208d82d720f462107b4949a16bfa3513 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:44:58 +0100 Subject: [PATCH 100/206] fix import of `Random` Co-authored-by: David Widmann --- src/AdvancedVI.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index cca220f1..a314e992 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -4,7 +4,7 @@ module AdvancedVI using SimpleUnPack: @unpack, @pack! using Accessors -import Random: AbstractRNG, default_rng +using Random: AbstractRNG, default_rng using Distributions import Distributions: logpdf, _logpdf, rand, _rand!, _rand!, From 83dec9fdc25226ed2dff13cc576981f09351a229 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:46:08 +0100 Subject: [PATCH 101/206] refactor `__init__()` Co-authored-by: David Widmann --- src/AdvancedVI.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a314e992..348a6a30 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -101,8 +101,8 @@ if !isdefined(Base, :get_extension) # check whether :get_extension is defined in using Requires end -function __init__() - @static if !isdefined(Base, :get_extension) +@static if !isdefined(Base, :get_extension) + function __init__() @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("../ext/AdvancedVIEnzymeExt.jl") end From a3e563cd43d937602e01f36e87247068f2a0b4ab Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:47:08 +0100 Subject: [PATCH 102/206] fix type constraint in definition of `value_and_gradient!` Co-authored-by: David Widmann --- src/AdvancedVI.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 348a6a30..42cd0dc5 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -39,9 +39,9 @@ const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) value_and_gradient!( ad::ADTypes.AbstractADType, f, - θ::AbstractVector{T}, + θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult - ) where {T<:Real} + ) Compute the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`. The result is stored in `out`. From 5553bb950840ea9b8c6aba7794f52d58d3fce910 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:52:56 +0100 Subject: [PATCH 103/206] refactor `ZygoteExt`; use `only` instead of `first` Co-authored-by: David Widmann --- ext/AdvancedVIZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index c3d891bb..7b8f8817 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -17,7 +17,7 @@ function AdvancedVI.value_and_gradient!( y, back = Zygote.pullback(f, θ) ∇θ = back(one(y)) DiffResults.value!(out, y) - DiffResults.gradient!(out, first(∇θ)) + DiffResults.gradient!(out, only(∇θ)) return out end From 79b455746860f7957a7703d8d99fcbd79e613409 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:53:38 +0100 Subject: [PATCH 104/206] refactor type constraint in `ReverseDiffExt` Co-authored-by: David Widmann --- ext/AdvancedVIReverseDiffExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index fd7fbaab..520cd9ff 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -13,8 +13,8 @@ end # ReverseDiff without compiled tape function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) tp = ReverseDiff.GradientTape(f, θ) ReverseDiff.gradient!(out, tp, θ) return out From 656b44b03f86ea83cba1d8de3953db956ffbe0ab Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 21 Aug 2023 23:56:28 +0100 Subject: [PATCH 105/206] refactor remove outdated debug mode macro --- src/AdvancedVI.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 42cd0dc5..ae0dc684 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -32,8 +32,6 @@ using Bijectors using StatsBase using StatsBase: entropy -const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) - # derivatives """ value_and_gradient!( From c7940636a8e08f5a97740f9872c4cffb4e6bed4d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 00:10:00 +0100 Subject: [PATCH 106/206] fix remove outdated DEBUG mechanism --- src/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index 43b06689..57ee8030 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -84,7 +84,7 @@ function optimize( stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - AdvancedVI.DEBUG && @debug "Step $t" stat... + @debug "Iteration $t" stat... pm_next!(prog, stat) push!(stats, stat) From 0c5cc1ce8eacc3451bf360bb3c1b0301415242d4 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 00:13:43 +0100 Subject: [PATCH 107/206] fix LaTeX in README: `operatorname` is currently broken --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 83c2e8bc..b3538ccf 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ For example, for the normal-log-normal model: $$ \begin{aligned} -x &\sim \operatorname{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ +x &\sim \mathrm{Log\text{-}Normal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) \end{aligned} $$ From 29d7d27ca227413275174e12f9258b13b8276fd0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:04:43 +0100 Subject: [PATCH 108/206] remove `SimpleUnPack` dependency --- Project.toml | 1 - docs/Project.toml | 1 - docs/src/advi.md | 9 +++------ docs/src/started.md | 11 ++++------- src/AdvancedVI.jl | 1 - src/distributions/location_scale.jl | 14 +++++++------- test/Project.toml | 1 - test/advi_locscale.jl | 3 +-- test/models/normal.jl | 2 +- test/models/normallognormal.jl | 7 ++++--- test/optimize.jl | 1 - 11 files changed, 20 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index e099308a..29cc559f 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/docs/Project.toml b/docs/Project.toml index 182edd3e..1f4ba59f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,7 +9,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [compat] ADTypes = "0.1.6" diff --git a/docs/src/advi.md b/docs/src/advi.md index afb780cb..88c11fee 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -116,7 +116,6 @@ StickingTheLandingEntropy ```@setup stl using LogDensityProblems -using SimpleUnPack using PDMats using Bijectors using LinearAlgebra @@ -134,7 +133,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -151,17 +150,15 @@ n_dims = 10 σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); q0 = AVI.VIMeanFieldGaussian(μ, L) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); - function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) diff --git a/docs/src/started.md b/docs/src/started.md index b89a140a..4a1d26ec 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -27,7 +27,6 @@ ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. Using the `LogDensityProblems` interface, we the model can be defined as follows: ```@example advi using LogDensityProblems -using SimpleUnPack struct NormalLogNormal{MX,SX,MY,SY} μ_x::MX @@ -37,7 +36,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -51,14 +50,14 @@ end ``` Let's now instantiate the model ```@example advi -using PDMats +using LinearAlgebra n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. @@ -67,7 +66,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat using Bijectors function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -94,8 +93,6 @@ objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. ```@example advi -using LinearAlgebra - d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index ae0dc684..5d0c3f8d 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,6 @@ module AdvancedVI -using SimpleUnPack: @unpack, @pack! using Accessors using Random: AbstractRNG, default_rng diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 3113c679..73be42b9 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -35,42 +35,42 @@ Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) function StatsBase.entropy(q::VILocationScale) - @unpack location, scale, dist = q + (; location, scale, dist) = q n_dims = length(location) n_dims*entropy(dist) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - @unpack location, scale, dist = q + (; location, scale, dist) = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - @unpack location, scale, dist = q + (; location, scale, dist) = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) - @unpack location, scale, dist = q + (; location, scale, dist) = q n_dims = length(location) scale*rand(dist, n_dims) + location end function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) - @unpack location, scale, dist = q + (; location, scale, dist) = q n_dims = length(location) scale*rand(rng, dist, n_dims, num_samples) .+ location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) - @unpack location, scale, dist = q + (; location, scale, dist) = q rand!(rng, dist, x) x .= scale*x return x += location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) - @unpack location, scale, dist = q + (; location, scale, dist) = q rand!(rng, dist, x) x *= scale return x += location diff --git a/test/Project.toml b/test/Project.toml index 2f38c88f..277b73c8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,7 +14,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 71cf22d5..c6aee68b 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -8,7 +8,6 @@ using Optimisers using Distributions using PDMats using LinearAlgebra -using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M @@ -48,7 +47,7 @@ include("models/utils.jl") T = 10000 modelstats = modelconstr(realtype; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/models/normal.jl b/test/models/normal.jl index f60ad5f3..1dfa653c 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -5,7 +5,7 @@ struct TestMvNormal{M,S} end function LogDensityProblems.logdensity(model::TestMvNormal, θ) - @unpack μ, Σ = model + (; μ, Σ) = model logpdf(MvNormal(μ, Σ), θ) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index cab73cce..49da5bf6 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) end function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -56,7 +56,8 @@ function normallognormal_meanfield(realtype; rng = default_rng()) μ_y = randn(rng, realtype, n_dims) σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + #model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) μ = vcat(μ_x, μ_y) L = vcat(σ_x, σ_y) |> Diagonal diff --git a/test/optimize.jl b/test/optimize.jl index 920a3070..c96fa6cd 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -6,7 +6,6 @@ using Optimisers using Distributions using PDMats using LinearAlgebra -using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M From 75eef445a5daea37d79106851b26af292de2542b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:05:08 +0100 Subject: [PATCH 109/206] fix LaTeX in docs and README --- README.md | 2 +- docs/src/started.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b3538ccf..d9638bfd 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ For example, for the normal-log-normal model: $$ \begin{aligned} -x &\sim \mathrm{Log\text{-}Normal}\left(\mu_x, \sigma_x^2\right) \\ +x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) \end{aligned} $$ diff --git a/docs/src/started.md b/docs/src/started.md index 4a1d26ec..a129fc46 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -18,8 +18,8 @@ optimize In this tutorial, we will work with a basic `normal-log-normal` model. ```math \begin{aligned} -x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) +x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) \end{aligned} ``` ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. From 40574f46864513ced4051867159e0660b2f4b061 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:10:29 +0100 Subject: [PATCH 110/206] add warning about forward-mode AD when using `LocationScale` --- docs/src/locscale.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/locscale.md b/docs/src/locscale.md index 63ff5cb4..8f14a9ad 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -23,6 +23,9 @@ where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. Notice the ``\mathcal{H}(\varphi)`` does not depend on ``\log |C|``. The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. +!!! warning + `LocationScale` and its specializations such as `VIFullRankGaussian` and `VIMeanFieldGaussian` are inefficient with forward-mode differentiation packages like `ForwardDiff`. Especially, they scale poorly with the number of dimensions. Please use reverse-mode differentation packages such as `ReverseDiff` and `Zygote`. + ## Constructors ```@docs From 8738256bd44fc38dd49807a69f70da41fa50448c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:14:04 +0100 Subject: [PATCH 111/206] fix documentation --- README.md | 7 +++---- docs/src/started.md | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d9638bfd..07407fa9 100644 --- a/README.md +++ b/README.md @@ -8,17 +8,16 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije ## Examples -`AdvancedVI` basically expects a `LogDensityProblem`. +`AdvancedVI` expects a `LogDensityProblem`. For example, for the normal-log-normal model: $$ \begin{aligned} x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) +y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right), \end{aligned} $$ - -A `LogDensityProblem` can be implemented as +a `LogDensityProblem` can be implemented as ```julia using LogDensityProblems diff --git a/docs/src/started.md b/docs/src/started.md index a129fc46..b07a5bd3 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -15,7 +15,7 @@ optimize ``` ## `ADVI` Example -In this tutorial, we will work with a basic `normal-log-normal` model. +In this tutorial, we will work with a `normal-log-normal` model. ```math \begin{aligned} x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ From 817374403e58cb11e4e0e3aaee045c350d5bdfdc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:18:52 +0100 Subject: [PATCH 112/206] fix remove reamining use of `@unpack` --- test/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimize.jl b/test/optimize.jl index c96fa6cd..96930495 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -25,7 +25,7 @@ include("models/utils.jl") T = 1000 modelstats = normallognormal_meanfield(Float64; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats # Global Test Configurations b⁻¹ = Bijectors.bijector(model) |> inverse From e0548aecdc3468aa836d58b55aa3be60124d4782 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 22:22:02 -0400 Subject: [PATCH 113/206] Revert "remove `SimpleUnPack` dependency" This reverts commit 29d7d27ca227413275174e12f9258b13b8276fd0. --- Project.toml | 1 + docs/Project.toml | 1 + docs/src/advi.md | 9 ++++++--- docs/src/started.md | 11 +++++++---- src/AdvancedVI.jl | 1 + src/distributions/location_scale.jl | 14 +++++++------- test/Project.toml | 1 + test/advi_locscale.jl | 3 ++- test/models/normal.jl | 2 +- test/models/normallognormal.jl | 7 +++---- test/optimize.jl | 1 + 11 files changed, 31 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 29cc559f..e099308a 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/docs/Project.toml b/docs/Project.toml index 1f4ba59f..182edd3e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,6 +9,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [compat] ADTypes = "0.1.6" diff --git a/docs/src/advi.md b/docs/src/advi.md index 88c11fee..afb780cb 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -116,6 +116,7 @@ StickingTheLandingEntropy ```@setup stl using LogDensityProblems +using SimpleUnPack using PDMats using Bijectors using LinearAlgebra @@ -133,7 +134,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -150,15 +151,17 @@ n_dims = 10 σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); q0 = AVI.VIMeanFieldGaussian(μ, L) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); + function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) diff --git a/docs/src/started.md b/docs/src/started.md index b07a5bd3..4e2b4380 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -27,6 +27,7 @@ ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. Using the `LogDensityProblems` interface, we the model can be defined as follows: ```@example advi using LogDensityProblems +using SimpleUnPack struct NormalLogNormal{MX,SX,MY,SY} μ_x::MX @@ -36,7 +37,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -50,14 +51,14 @@ end ``` Let's now instantiate the model ```@example advi -using LinearAlgebra +using PDMats n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. @@ -66,7 +67,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat using Bijectors function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -93,6 +94,8 @@ objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. ```@example advi +using LinearAlgebra + d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5d0c3f8d..ae0dc684 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,6 +1,7 @@ module AdvancedVI +using SimpleUnPack: @unpack, @pack! using Accessors using Random: AbstractRNG, default_rng diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 73be42b9..3113c679 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -35,42 +35,42 @@ Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) function StatsBase.entropy(q::VILocationScale) - (; location, scale, dist) = q + @unpack location, scale, dist = q n_dims = length(location) n_dims*entropy(dist) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) - (; location, scale, dist) = q + @unpack location, scale, dist = q n_dims = length(location) scale*rand(dist, n_dims) + location end function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) - (; location, scale, dist) = q + @unpack location, scale, dist = q n_dims = length(location) scale*rand(rng, dist, n_dims, num_samples) .+ location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q rand!(rng, dist, x) x .= scale*x return x += location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q rand!(rng, dist, x) x *= scale return x += location diff --git a/test/Project.toml b/test/Project.toml index 277b73c8..2f38c88f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index c6aee68b..71cf22d5 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -8,6 +8,7 @@ using Optimisers using Distributions using PDMats using LinearAlgebra +using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M @@ -47,7 +48,7 @@ include("models/utils.jl") T = 10000 modelstats = modelconstr(realtype; rng) - (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/models/normal.jl b/test/models/normal.jl index 1dfa653c..f60ad5f3 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -5,7 +5,7 @@ struct TestMvNormal{M,S} end function LogDensityProblems.logdensity(model::TestMvNormal, θ) - (; μ, Σ) = model + @unpack μ, Σ = model logpdf(MvNormal(μ, Σ), θ) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 49da5bf6..cab73cce 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) end function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -56,8 +56,7 @@ function normallognormal_meanfield(realtype; rng = default_rng()) μ_y = randn(rng, realtype, n_dims) σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - #model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) - model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) μ = vcat(μ_x, μ_y) L = vcat(σ_x, σ_y) |> Diagonal diff --git a/test/optimize.jl b/test/optimize.jl index 96930495..c1a604c1 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -6,6 +6,7 @@ using Optimisers using Distributions using PDMats using LinearAlgebra +using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M From 6ab95a096e058d21b9df1bb335d09381ce097705 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 22:23:25 -0400 Subject: [PATCH 114/206] Revert "fix remove reamining use of `@unpack`" This reverts commit 817374403e58cb11e4e0e3aaee045c350d5bdfdc. --- test/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimize.jl b/test/optimize.jl index c1a604c1..920a3070 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -26,7 +26,7 @@ include("models/utils.jl") T = 1000 modelstats = normallognormal_meanfield(Float64; rng) - (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats # Global Test Configurations b⁻¹ = Bijectors.bijector(model) |> inverse From f0ec242e615fb9f3f7b4b05ea2a687fa9c0e8b0c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 18:08:01 +0100 Subject: [PATCH 115/206] fix documentation for `optimize` --- src/optimize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 57ee8030..b18c8581 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -35,7 +35,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. +- `callback!`: Callback function called after every iteration. The signature is `cb(; est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) # Returns @@ -80,7 +80,7 @@ function optimize( stat = merge(stat, stat′) if !isnothing(callback!) - stat′ = callback!(; est_state, stat, λ, g, restructure) + stat′ = callback!(; est_state, stat, restructure, λ, g) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end From 1d4c1b6877296a7bdca5ed38c9d34c5be3acc827 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 18:08:13 +0100 Subject: [PATCH 116/206] add specializations of `Optimise.destructure` for mean-field * This fixes the poor performance of `ForwardDiff` * This prevents the zero elements of the mean-field scale being extracted --- docs/src/locscale.md | 3 --- src/distributions/location_scale.jl | 35 ++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/docs/src/locscale.md b/docs/src/locscale.md index 8f14a9ad..63ff5cb4 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -23,9 +23,6 @@ where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. Notice the ``\mathcal{H}(\varphi)`` does not depend on ``\log |C|``. The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. -!!! warning - `LocationScale` and its specializations such as `VIFullRankGaussian` and `VIMeanFieldGaussian` are inefficient with forward-mode differentiation packages like `ForwardDiff`. Especially, they scale poorly with the number of dimensions. Please use reverse-mode differentation packages such as `ReverseDiff` and `Zygote`. - ## Constructors ```@docs diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 3113c679..9ae749f2 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -19,9 +19,8 @@ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution dist ::D function VILocationScale(location::AbstractVector{<:Real}, - scale::Union{<:AbstractTriangular{<:Real}, - <:Diagonal{<:Real}}, - dist::ContinuousUnivariateDistribution) + scale ::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, + dist ::ContinuousUnivariateDistribution) # Restricting all the arguments to have the same types creates problems # with dual-variable-based AD frameworks. @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2)) @@ -31,6 +30,32 @@ end Functors.@functor VILocationScale (location, scale) +# Specialization of `Optimisers.destructure` for mean-field location-scale families. +# These are necessary because we only want to extract the diagonal elements of +# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD +# is very inefficient. +# begin +struct RestructureMeanField{L, S<:Diagonal, D} + q::VILocationScale{L, S, D} +end + +function (re::RestructureMeanField)(flat::AbstractVector) + n_dims = div(length(flat), 2) + location = first(flat, n_dims) + scale = Diagonal(last(flat, n_dims)) + VILocationScale(location, scale, re.q.dist) +end + +function Optimisers.destructure( + q::VILocationScale{L, <:Diagonal, D} +) where {L, D} + @unpack location, scale, dist = q + flat = vcat(location, diag(scale)) + n_dims = length(location) + flat, RestructureMeanField(q) +end +# end + Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) @@ -42,12 +67,12 @@ end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) From 231835f719f6fce86a4e0cf9935431b53cce75c7 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 20:01:41 +0100 Subject: [PATCH 117/206] add test for `Optimisers.destructure` specializations --- test/distributions.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/distributions.jl b/test/distributions.jl index 9b18d020..dcd20696 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,6 +1,7 @@ using ReTest using Distributions: _logpdf +using Optimisers @testset "distributions" begin @testset "$(string(covtype)) $(basedist) $(realtype)" for @@ -55,4 +56,15 @@ using Distributions: _logpdf @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) end end + + @testset "Diagonal destructure" for + n_dims = 10 + μ = zeros(n_dims) + L = ones(n_dims) + q = VIMeanFieldGaussian(μ, L |> Diagonal) + λ, re = Optimisers.destructure(q) + + @test length(λ) == 2*n_dims + @test q == re(λ) + end end From ea2d426c2c9b96de7d640e9ab0add3b4ae853892 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 21:21:54 +0100 Subject: [PATCH 118/206] add specialization of `rand` for meanfield resulting in faster AD --- src/distributions/location_scale.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 9ae749f2..7eb1f708 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -87,6 +87,16 @@ function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) scale*rand(rng, dist, n_dims, num_samples) .+ location end +# This specialization improves AD performance of the sampling path +function rand( + rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int +) where {L, D} + @unpack location, scale, dist = q + n_dims = length(location) + scale_diag = diag(scale) + scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location +end + function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) @unpack location, scale, dist = q rand!(rng, dist, x) From 3033d75938b9d37408bbe081bf73c7954aff09cf Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 21:42:16 +0100 Subject: [PATCH 119/206] add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian` --- src/distributions/location_scale.jl | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 7eb1f708..a7d9fbe4 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -6,12 +6,15 @@ The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. It generally represents any distribution for which the sampling path can be -represented as the following: +represented as follows: ```julia d = length(location) u = rand(dist, d) z = scale*u + location ``` + +!!! note + For stable convergence, the initial scale needs to be sufficiently large. """ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution location::L @@ -112,21 +115,37 @@ function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) end """ - VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}; check_args = true) This constructs a multivariate Gaussian distribution with a full rank covariance matrix. """ -function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) where {T <: Real} +function VIFullRankGaussian( + μ::AbstractVector{T}, + L::AbstractTriangular{T}; + check_args::Bool = true +) where {T <: Real} + @assert isposdef(L) "Scale must be positive definite" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." + end q_base = Normal{T}(zero(T), one(T)) VILocationScale(μ, L, q_base) end """ - VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) + VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}; check_args = true) This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. """ -function VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T <: Real} +function VIMeanFieldGaussian( + μ::AbstractVector{T}, + L::Diagonal{T}; + check_args::Bool = true +) where {T <: Real} + @assert isposdef(L) "Scale must be positive definite" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." + end q_base = Normal{T}(zero(T), one(T)) VILocationScale(μ, L, q_base) end From 0cc36c0eb9f4fc701e73e5ee835e5e0ced0c88d1 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 21:55:02 +0100 Subject: [PATCH 120/206] update documentation --- docs/src/advi.md | 5 ++--- docs/src/locscale.md | 4 ++++ src/distributions/location_scale.jl | 3 --- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index afb780cb..2cf6a773 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -210,8 +210,8 @@ _, stats_stl, _ = AVI.optimize( t = [stat.iteration for stat ∈ stats_cfe] y_cfe = [stat.elbo for stat ∈ stats_cfe] y_stl = [stat.elbo for stat ∈ stats_stl] -plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO") -plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO") +plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) +plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) savefig("advi_stl_elbo.svg") nothing ``` @@ -220,7 +220,6 @@ nothing We can see that the noise of the STL estimator becomes smaller as VI converges. However, the speed of convergence may not always be significantly different. - ## References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. diff --git a/docs/src/locscale.md b/docs/src/locscale.md index 63ff5cb4..a5966f44 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -25,6 +25,10 @@ The derivative of the entropy with respect to ``\lambda`` is thus independent of ## Constructors +!!! note + For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. + Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. + ```@docs VILocationScale ``` diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index a7d9fbe4..ce14d724 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -12,9 +12,6 @@ represented as follows: u = rand(dist, d) z = scale*u + location ``` - -!!! note - For stable convergence, the initial scale needs to be sufficiently large. """ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution location::L From b7d3471fdd81b44a07dac068f1d84a260bb4959a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:54:18 +0100 Subject: [PATCH 121/206] fix type instability, bug in argument check in `LocationScale` --- src/distributions/location_scale.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index ce14d724..ab12db84 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -57,12 +57,15 @@ end # end Base.length(q::VILocationScale) = length(q.location) + Base.size(q::VILocationScale) = size(q.location) +Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D) + function StatsBase.entropy(q::VILocationScale) @unpack location, scale, dist = q n_dims = length(location) - n_dims*entropy(dist) + first(logabsdet(scale)) + n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @@ -121,7 +124,7 @@ function VIFullRankGaussian( L::AbstractTriangular{T}; check_args::Bool = true ) where {T <: Real} - @assert isposdef(L) "Scale must be positive definite" + @assert eigmin(L) > eps(eltype(L)) "Scale must be positive definite" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end @@ -139,7 +142,7 @@ function VIMeanFieldGaussian( L::Diagonal{T}; check_args::Bool = true ) where {T <: Real} - @assert isposdef(L) "Scale must be positive definite" + @assert eigmin(L) > eps(eltype(L)) "Scale must be a Cholesky factor" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end From df50e8346e2d3174c6e57f41812e25f5d9c9751e Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:57:24 +0100 Subject: [PATCH 122/206] add missing import bug --- src/AdvancedVI.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index ae0dc684..16807542 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -7,7 +7,7 @@ using Accessors using Random: AbstractRNG, default_rng using Distributions import Distributions: - logpdf, _logpdf, rand, _rand!, _rand!, + logpdf, _logpdf, rand, rand!, _rand!, ContinuousMultivariateDistribution using Functors @@ -26,7 +26,6 @@ using ADTypes: AbstractADType using ChainRules: @ignore_derivatives using FillArrays -using PDMats using Bijectors using StatsBase From ae3e9b018518b803ed60b6eaf7c5400cdf040a10 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:57:43 +0100 Subject: [PATCH 123/206] refactor test, fix type bug in tests for `LocationScale` --- test/ad.jl | 2 -- test/advi_locscale.jl | 24 +++--------------------- test/distributions.jl | 27 ++++++++++++--------------- test/models/utils.jl | 8 -------- test/optimize.jl | 18 ------------------ test/runtests.jl | 23 +++++++++++++++++++++++ 6 files changed, 38 insertions(+), 64 deletions(-) delete mode 100644 test/models/utils.jl diff --git a/test/ad.jl b/test/ad.jl index 2c4f802a..f575b485 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,7 +1,5 @@ using ReTest -using ForwardDiff, ReverseDiff, Enzyme, Zygote -using ADTypes @testset "ad" begin @testset "$(adname)" for (adname, adsymbol) ∈ Dict( diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 71cf22d5..a7dcc98b 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -2,25 +2,6 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using ReTest -using Bijectors -using LogDensityProblems -using Optimisers -using Distributions -using PDMats -using LinearAlgebra -using SimpleUnPack: @unpack - -struct TestModel{M,L,S} - model::M - μ_true::L - L_true::S - n_dims::Int - is_meanfield::Bool -end - -include("models/normallognormal.jl") -include("models/normal.jl") -include("models/utils.jl") @testset "advi" begin @testset "locscale" begin @@ -55,10 +36,11 @@ include("models/utils.jl") μ₀ = zeros(realtype, n_dims) L₀ = if is_meanfield - ones(realtype, n_dims) |> Diagonal + FillArrays.Eye(n_dims) |> Diagonal else - diagm(ones(realtype, n_dims)) |> LowerTriangular + FillArrays.Eye(n_dims) |> LowerTriangular end + q₀ = if is_meanfield VIMeanFieldGaussian(μ₀, L₀) else diff --git a/test/distributions.jl b/test/distributions.jl index dcd20696..563de12d 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,7 +1,6 @@ using ReTest using Distributions: _logpdf -using Optimisers @testset "distributions" begin @testset "$(string(covtype)) $(basedist) $(realtype)" for @@ -11,35 +10,33 @@ using Optimisers seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) rng = Philox4x(UInt64, seed, 8) - realtype = Float64 - ϵ = 1f-2 n_dims = 10 n_montecarlo = 1000_000 - μ = randn(rng, realtype, n_dims) - L₀ = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - Σ = if covtype == :fullrank - Σ = (L₀*L₀' + ϵ*I) |> Hermitian + μ = randn(rng, realtype, n_dims) + L = if covtype == :fullrank + sample_cholesky(rng, realtype, n_dims) else Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) end + Σ = L*L' - L = cholesky(Σ).L q = if covtype == :fullrank && basedist == :gaussian - VIFullRankGaussian(μ, L |> LowerTriangular) + VIFullRankGaussian(μ, L) elseif covtype == :meanfield && basedist == :gaussian - VIMeanFieldGaussian(μ, L |> Diagonal) + VIMeanFieldGaussian(μ, L) end q_true = if basedist == :gaussian MvNormal(μ, Σ) end @testset "logpdf" begin - z = randn(rng, realtype, n_dims) - @test logpdf(q, z) ≈ logpdf(q_true, z) - @test _logpdf(q, z) ≈ _logpdf(q_true, z) - @test eltype(logpdf(q, z)) == realtype - @test eltype(_logpdf(q, z)) == realtype + z = rand(rng, q) + @test eltype(z) == realtype + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) + @test _logpdf(q, z) ≈ _logpdf(q_true, z) rtol=realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype + @test eltype(_logpdf(q, z)) == realtype end @testset "entropy" begin diff --git a/test/models/utils.jl b/test/models/utils.jl deleted file mode 100644 index 3d483c46..00000000 --- a/test/models/utils.jl +++ /dev/null @@ -1,8 +0,0 @@ - -function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) - A = randn(rng, type, n_dims, n_dims) - L = tril(A) - idx = diagind(L) - @. L[idx] = log(exp(L[idx]) + 1) - L |> LowerTriangular -end diff --git a/test/optimize.jl b/test/optimize.jl index 920a3070..5686b724 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -1,23 +1,5 @@ using ReTest -using Bijectors -using LogDensityProblems -using Optimisers -using Distributions -using PDMats -using LinearAlgebra -using SimpleUnPack: @unpack - -struct TestModel{M,L,S} - model::M - μ_true::L - L_true::S - n_dims::Int - is_meanfield::Bool -end - -include("models/normallognormal.jl") -include("models/utils.jl") @testset "optimize" begin seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) diff --git a/test/runtests.jl b/test/runtests.jl index 6bd3bc49..803c11c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,9 +8,32 @@ using Random123 using Statistics using Distributions using LinearAlgebra +using SimpleUnPack: @unpack +using PDMats + +using Bijectors +using LogDensityProblems +using Optimisers +using ADTypes +using ForwardDiff, ReverseDiff, Zygote using AdvancedVI +# Utilities +include("utils.jl") + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("models/normal.jl") +include("models/normallognormal.jl") + +# Tests include("ad.jl") include("distributions.jl") include("advi_locscale.jl") From e4002cfeb0f8edd7dd8cf02e6ee68f1eb2bf959a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:58:08 +0100 Subject: [PATCH 124/206] add missing compat entries --- Project.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e099308a..87aa4aac 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,6 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -37,17 +36,21 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" +Accessors = "0.1.32" Bijectors = "0.11, 0.12, 0.13" ChainRules = "1.53.0" DiffResults = "1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" +FillArrays = "1.6.0" ForwardDiff = "0.10.25" +Functors = "0.4.5" LogDensityProblems = "2.1.1" Optimisers = "0.2.16" ProgressMeter = "1.0.0" Requires = "0.5, 1.0" ReverseDiff = "1.14" +SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" julia = "1.6" From 8c82569208199480676de7b583cd54ff079ba8c5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:19:26 +0100 Subject: [PATCH 125/206] fix missing package import in test --- test/Project.toml | 1 + test/runtests.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 2f38c88f..663d671d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/test/runtests.jl b/test/runtests.jl index 803c11c7..8a6e486e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using Distributions using LinearAlgebra using SimpleUnPack: @unpack using PDMats +using FillArrays using Bijectors using LogDensityProblems From c2e751723a63cd00b5f223a390dc34513b94b946 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:19:34 +0100 Subject: [PATCH 126/206] add additional tests for sampling `LocationScale` --- test/distributions.jl | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/test/distributions.jl b/test/distributions.jl index 563de12d..c603421e 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -31,6 +31,9 @@ using Distributions: _logpdf end @testset "logpdf" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + z = rand(rng, q) @test eltype(z) == realtype @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) @@ -45,12 +48,38 @@ using Distributions: _logpdf end @testset "sampling" begin - z_samples = rand(rng, q, n_montecarlo) - threesigma = L - @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @testset "rand" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = mapreduce(x -> rand(rng, q), hcat, 1:n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand batch" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = rand(rng, q, n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand!" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(rng, q, z_samples) + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end end end From 3a6f8bf689af5657d817674d84a886d3496864d6 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:19:50 +0100 Subject: [PATCH 127/206] fix bug in batch in-place `rand!` for `LocationScale` --- src/distributions/location_scale.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index ab12db84..ecb0b672 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -110,8 +110,8 @@ end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) @unpack location, scale, dist = q rand!(rng, dist, x) - x *= scale - return x += location + x[:] = scale*x + return x .+= location end """ From b78ef4bf3afe6649d124320595edde71d3031e02 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:39:16 +0100 Subject: [PATCH 128/206] fix bug in inference test initialization --- test/advi_locscale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index a7dcc98b..76ae3724 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -38,7 +38,7 @@ using ReTest L₀ = if is_meanfield FillArrays.Eye(n_dims) |> Diagonal else - FillArrays.Eye(n_dims) |> LowerTriangular + FillArrays.Eye(n_dims) |> Matrix |> LowerTriangular end q₀ = if is_meanfield From a1f7e98a612bc8e7b840c4c341ee8870aac9e29f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 23 Aug 2023 01:29:50 +0100 Subject: [PATCH 129/206] add missing file --- test/utils.jl | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 test/utils.jl diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 00000000..3d483c46 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,8 @@ + +function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) + A = randn(rng, type, n_dims, n_dims) + L = tril(A) + idx = diagind(L) + @. L[idx] = log(exp(L[idx]) + 1) + L |> LowerTriangular +end From 8b783eca14a21cc620f311f9f63417e9f31e5de8 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 22 Aug 2023 21:46:01 -0400 Subject: [PATCH 130/206] fix remove use of for 1.6 --- src/distributions/location_scale.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index ecb0b672..91b6768a 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -124,7 +124,7 @@ function VIFullRankGaussian( L::AbstractTriangular{T}; check_args::Bool = true ) where {T <: Real} - @assert eigmin(L) > eps(eltype(L)) "Scale must be positive definite" + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end @@ -142,7 +142,7 @@ function VIMeanFieldGaussian( L::Diagonal{T}; check_args::Bool = true ) where {T <: Real} - @assert eigmin(L) > eps(eltype(L)) "Scale must be a Cholesky factor" + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end From 12cd9f22611f3bf1a95ea878ade7c3f151957cd9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 23 Aug 2023 21:00:51 +0100 Subject: [PATCH 131/206] refactor adjust inference test hyperparameters to be more robust --- test/advi_locscale.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 76ae3724..524dc5e2 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -27,10 +27,11 @@ using ReTest seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) rng = Philox4x(UInt64, seed, 8) - T = 10000 modelstats = modelconstr(realtype; rng) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -53,7 +54,7 @@ using ReTest Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-2), + optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -72,7 +73,7 @@ using ReTest rng = Philox4x(UInt64, seed, 8) q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), + optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -83,7 +84,7 @@ using ReTest rng_repl = Philox4x(UInt64, seed, 8) q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), + optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng_repl, adbackend = adbackend, From 837c7296467ae20c66f7c061a6142295ebe50b22 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 24 Aug 2023 02:43:39 +0100 Subject: [PATCH 132/206] refactor `optimize` to return `obj_state`, add warm start kwargs --- docs/src/advi.md | 4 +-- docs/src/started.md | 4 +-- src/AdvancedVI.jl | 6 ----- src/objectives/elbo/advi.jl | 49 +++++++++++++++++-------------------- src/optimize.jl | 27 +++++++++++++------- test/advi_locscale.jl | 12 ++++----- test/optimize.jl | 14 +++++------ 7 files changed, 57 insertions(+), 59 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 2cf6a773..3ac90436 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -189,7 +189,7 @@ stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), i ```@setup stl n_max_iter = 10^4 -_, stats_cfe, _ = AVI.optimize( +_, stats_cfe, _, _ = AVI.optimize( cfe, q0, n_max_iter; @@ -198,7 +198,7 @@ _, stats_cfe, _ = AVI.optimize( optimizer = Optimisers.Adam(1e-3) ); -_, stats_stl, _ = AVI.optimize( +_, stats_stl, _, _ = AVI.optimize( stl, q0, n_max_iter; diff --git a/docs/src/started.md b/docs/src/started.md index 4e2b4380..f3ae54b1 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -103,8 +103,8 @@ q = AVI.VIMeanFieldGaussian(μ, L) ``` Passing `objective` and the initial variational approximation `q` to `optimize` performs inference. ```@example advi -n_max_iter = 10^4 -q, stats, _ = AVI.optimize( +n_max_iter = 10^4 +q, stats, _, _ = AVI.optimize( objective, q, n_max_iter; diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 16807542..9bc3d316 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -53,14 +53,8 @@ abstract type AbstractVariationalObjective end function init end function estimate_gradient end -init(::Nothing) = nothing - # ADVI-specific interfaces abstract type AbstractEntropyEstimator end -abstract type AbstractControlVariate end - -function update end -update(::Nothing, ::Nothing) = (nothing, nothing) # entropy.jl must preceed advi.jl include("objectives/elbo/entropy.jl") diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index d8719fa7..f9a61d81 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -19,18 +19,15 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. """ -struct ADVI{Tlogπ, B, - EntropyEst <: AbstractEntropyEstimator, - ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective - ℓπ::Tlogπ - invbij::B - entropy::EntropyEst - cv::ControlVar +struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + prob ::P + invbij ::B + entropy ::EntropyEst n_samples::Int - function ADVI(prob, n_samples::Int; - entropy::AbstractEntropyEstimator = ClosedFormEntropy(), - cv::Union{<:AbstractControlVariate, Nothing} = nothing, + function ADVI(prob, + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy(), invbij = Bijectors.identity) cap = LogDensityProblems.capabilities(prob) if cap === nothing @@ -40,15 +37,16 @@ struct ADVI{Tlogπ, B, ), ) end - ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(invbij), typeof(entropy), typeof(cv)}(ℓπ, invbij, entropy, cv, n_samples) + new{typeof(prob), typeof(invbij), typeof(entropy)}( + prob, invbij, entropy, n_samples + ) end end Base.show(io::IO, advi::ADVI) = - print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))") + print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") -init(advi::ADVI) = init(advi.cv) +init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing function (advi::ADVI)( rng::AbstractRNG, @@ -57,7 +55,7 @@ function (advi::ADVI)( ) 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) - advi.ℓπ(zᵢ) + logdetjacᵢ + LogDensityProblems.logdensity(advi.prob, zᵢ) + logdetjacᵢ end ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ @@ -78,22 +76,22 @@ Evaluate the ELBO using the ADVI formulation. """ function (advi::ADVI)( - q_η::ContinuousMultivariateDistribution; - rng::AbstractRNG = default_rng(), - n_samples::Int = advi.n_samples + q_η ::ContinuousMultivariateDistribution; + rng ::AbstractRNG = default_rng(), + n_samples::Int = advi.n_samples ) ηs = rand(rng, q_η, n_samples) advi(rng, q_η, ηs) end function estimate_gradient( - rng::AbstractRNG, - adbackend::AbstractADType, - advi::ADVI, + rng ::AbstractRNG, + adbackend ::AbstractADType, + advi ::ADVI, est_state, - λ::Vector{<:Real}, + λ ::Vector{<:Real}, restructure, - out::DiffResults.MutableDiffResult + out ::DiffResults.MutableDiffResult ) f(λ′) = begin q_η = restructure(λ′) @@ -105,8 +103,5 @@ function estimate_gradient( nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - est_state, stat′ = update(advi.cv, est_state) - stat = !isnothing(stat′) ? merge(stat′, stat) : stat - - out, est_state, stat + out, nothing, stat end diff --git a/src/optimize.jl b/src/optimize.jl index b18c8581..54e7ace0 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -35,13 +35,18 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. +- `callback!`: Callback function called after every iteration. The signature is `cb(; obj_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `obj_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) +When resuming from the state of a previous run, use the following keyword arguments: +- `opt_state`: Initial state of the optimizer. +- `obj_state`: Initial state of the objective. + # Returns - `λ`: Variational parameters optimizing the variational objective. - `stats`: Statistics gathered during inference. - `opt_state`: Final state of the optimiser. +- `obj_state`: Final state of the objective. """ function optimize( objective ::AbstractVariationalObjective, @@ -52,6 +57,8 @@ function optimize( optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), rng ::AbstractRNG = default_rng(), show_progress::Bool = true, + opt_state = nothing, + obj_state = nothing, callback! = nothing, prog = ProgressMeter.Progress( n_max_iter; @@ -62,16 +69,16 @@ function optimize( ) ) λ = copy(λ₀) - opt_state = Optimisers.setup(optimizer, λ) - est_state = init(objective) + opt_state = isnothing(opt_state) ? Optimisers.setup(optimizer, λ) : opt_state + obj_state = isnothing(obj_state) ? init(rng, objective, λ, restructure) : obj_state grad_buf = DiffResults.GradientResult(λ) stats = NamedTuple[] for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, est_state, stat′ = estimate_gradient( - rng, adbackend, objective, est_state, λ, restructure, grad_buf) + grad_buf, obj_state, stat′ = estimate_gradient( + rng, adbackend, objective, obj_state, λ, restructure, grad_buf) stat = merge(stat, stat′) g = DiffResults.gradient(grad_buf) @@ -80,7 +87,7 @@ function optimize( stat = merge(stat, stat′) if !isnothing(callback!) - stat′ = callback!(; est_state, stat, restructure, λ, g) + stat′ = callback!(; obj_state, stat, restructure, λ, g) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end @@ -89,7 +96,7 @@ function optimize( pm_next!(prog, stat) push!(stats, stat) end - λ, map(identity, stats), opt_state + λ, map(identity, stats), opt_state, obj_state end function optimize(objective ::AbstractVariationalObjective, @@ -97,6 +104,8 @@ function optimize(objective ::AbstractVariationalObjective, n_max_iter::Int; kwargs...) λ, restructure = Optimisers.destructure(q₀) - λ, stats, opt_state = optimize(objective, restructure, λ, n_max_iter; kwargs...) - restructure(λ), stats, opt_state + λ, stats, opt_state, obj_state = optimize( + objective, restructure, λ, n_max_iter; kwargs... + ) + restructure(λ), stats, opt_state, obj_state end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 524dc5e2..e780b074 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -51,8 +51,8 @@ using ReTest obj = objective(model, b⁻¹, 10) @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats, _ = optimize( + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, @@ -70,8 +70,8 @@ using ReTest end @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + q, stats, _, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, @@ -81,8 +81,8 @@ using ReTest μ = q.location L = q.scale - rng_repl = Philox4x(UInt64, seed, 8) - q, stats, _ = optimize( + rng_repl = Philox4x(UInt64, seed, 8) + q, stats, _, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, diff --git a/test/optimize.jl b/test/optimize.jl index 5686b724..2369432c 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -20,8 +20,8 @@ using ReTest adbackend = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) - rng = Philox4x(UInt64, seed, 8) - q_ref, stats_ref, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + q_ref, stats_ref, _, _ = optimize( obj, q₀, T; optimizer, show_progress = false, @@ -33,8 +33,8 @@ using ReTest @testset "restructure" begin λ₀, re = Optimisers.destructure(q₀) - rng = Philox4x(UInt64, seed, 8) - λ, stats, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + λ, stats, _, _ = optimize( obj, re, λ₀, T; optimizer, show_progress = false, @@ -49,12 +49,12 @@ using ReTest rng = Philox4x(UInt64, seed, 8) test_values = rand(rng, T) - callback!(; stat, est_state, restructure, λ, g) = begin + callback!(; stat, obj_state, restructure, λ, g) = begin (test_value = test_values[stat.iteration],) end - rng = Philox4x(UInt64, seed, 8) - _, stats, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + _, stats, _, _ = optimize( obj, q₀, T; show_progress = false, rng, From 95629a5471f7e3e94a19b8096cd9df73d8dad523 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 23 Aug 2023 23:19:09 -0400 Subject: [PATCH 133/206] refactor make tests more robust, reduce amount of tests --- test/advi_locscale.jl | 2 -- test/distributions.jl | 2 +- test/models/normal.jl | 50 ---------------------------------- test/models/normallognormal.jl | 2 +- test/runtests.jl | 5 +--- test/utils.jl | 8 ------ 6 files changed, 3 insertions(+), 66 deletions(-) delete mode 100644 test/models/normal.jl delete mode 100644 test/utils.jl diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index e780b074..d5250ce8 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -10,8 +10,6 @@ using ReTest (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, - :NormalMeanField => normal_meanfield, - :NormalFullRank => normal_fullrank, ), (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹), diff --git a/test/distributions.jl b/test/distributions.jl index c603421e..175cc96b 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -15,7 +15,7 @@ using Distributions: _logpdf μ = randn(rng, realtype, n_dims) L = if covtype == :fullrank - sample_cholesky(rng, realtype, n_dims) + tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular else Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) end diff --git a/test/models/normal.jl b/test/models/normal.jl deleted file mode 100644 index f60ad5f3..00000000 --- a/test/models/normal.jl +++ /dev/null @@ -1,50 +0,0 @@ - -struct TestMvNormal{M,S} - μ::M - Σ::S -end - -function LogDensityProblems.logdensity(model::TestMvNormal, θ) - @unpack μ, Σ = model - logpdf(MvNormal(μ, Σ), θ) -end - -function LogDensityProblems.dimension(model::TestMvNormal) - length(model.μ) -end - -function LogDensityProblems.capabilities(::Type{<:TestMvNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -function Bijectors.bijector(model::TestMvNormal) - identity -end - -function normal_fullrank(realtype; rng = default_rng()) - n_dims = 5 - - μ = randn(rng, realtype, n_dims) - L₀ = sample_cholesky(rng, realtype, n_dims) - Σ = L₀*L₀' |> Hermitian - - Σ_chol = cholesky(Σ) - model = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol)) - - L = Σ_chol.L |> LowerTriangular - - TestModel(model, μ, L, n_dims, false) -end - -function normal_meanfield(realtype; rng = default_rng()) - n_dims = 5 - - μ = randn(rng, realtype, n_dims) - σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - - model = TestMvNormal(μ, PDMats.PDiagMat(σ)) - - L = σ |> Diagonal - - TestModel(model, μ, L, n_dims, true) -end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index cab73cce..f8b84a1b 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,7 +32,7 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L₀_y = sample_cholesky(rng, realtype, n_dims) + L₀_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular Σ_y = L₀_y*L₀_y' |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) diff --git a/test/runtests.jl b/test/runtests.jl index 8a6e486e..0a2c5e66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,9 +20,7 @@ using ForwardDiff, ReverseDiff, Zygote using AdvancedVI -# Utilities -include("utils.jl") - +# Models for Inference Tests struct TestModel{M,L,S} model::M μ_true::L @@ -31,7 +29,6 @@ struct TestModel{M,L,S} is_meanfield::Bool end -include("models/normal.jl") include("models/normallognormal.jl") # Tests diff --git a/test/utils.jl b/test/utils.jl deleted file mode 100644 index 3d483c46..00000000 --- a/test/utils.jl +++ /dev/null @@ -1,8 +0,0 @@ - -function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) - A = randn(rng, type, n_dims, n_dims) - L = tril(A) - idx = diagind(L) - @. L[idx] = log(exp(L[idx]) + 1) - L |> LowerTriangular -end From 0b4b865ae9376b35b776afca17baf58cea27b095 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:31:09 -0400 Subject: [PATCH 134/206] fix remove a cholesky in test model --- test/models/normallognormal.jl | 14 +++++++------- test/runtests.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index f8b84a1b..ec591f2c 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -29,13 +29,13 @@ end function normallognormal_fullrank(realtype; rng = default_rng()) n_dims = 5 - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - L₀_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular - Σ_y = L₀_y*L₀_y' |> Hermitian + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular + Σ_y = L_y*L_y' |> Hermitian - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y))) Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) Σ[1,1] = σ_x^2 @@ -56,7 +56,7 @@ function normallognormal_meanfield(realtype; rng = default_rng()) μ_y = randn(rng, realtype, n_dims) σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) μ = vcat(μ_x, μ_y) L = vcat(σ_x, σ_y) |> Diagonal diff --git a/test/runtests.jl b/test/runtests.jl index 0a2c5e66..127503be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,8 +9,8 @@ using Statistics using Distributions using LinearAlgebra using SimpleUnPack: @unpack -using PDMats using FillArrays +using PDMats using Bijectors using LogDensityProblems From b49f4ebc163e2feecba38fba2678e650dfbd788d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:31:34 -0400 Subject: [PATCH 135/206] fix compat bounds, remove unused package --- Project.toml | 28 ++++++++++++++-------------- src/AdvancedVI.jl | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 87aa4aac..143e2098 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,7 @@ version = "0.3.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -20,7 +20,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -36,23 +35,24 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" -Accessors = "0.1.32" -Bijectors = "0.11, 0.12, 0.13" -ChainRules = "1.53.0" +Accessors = "0.1" +Bijectors = "0.12, 0.13" +ChainRulesCore = "1.16" DiffResults = "1" -Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" +Distributions = "0.25.87" DocStringExtensions = "0.8, 0.9" -FillArrays = "1.6.0" -ForwardDiff = "0.10.25" -Functors = "0.4.5" -LogDensityProblems = "2.1.1" +Enzyme = "0.11.7" +FillArrays = "1.3" +ForwardDiff = "0.10.36" +Functors = "0.4" +LogDensityProblems = "2" Optimisers = "0.2.16" -ProgressMeter = "1.0.0" -Requires = "0.5, 1.0" -ReverseDiff = "1.14" +ProgressMeter = "1.6" +Requires = "1.0" +ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -StatsFuns = "0.8, 0.9, 1" +Zygote = "0.6.63" julia = "1.6" [extras] diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9bc3d316..7272303a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -23,7 +23,7 @@ using LogDensityProblems using ADTypes, DiffResults using ADTypes: AbstractADType -using ChainRules: @ignore_derivatives +using ChainRulesCore: @ignore_derivatives using FillArrays using Bijectors From 947a070da945505282711f6a45f6c3723b32b7fd Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:32:51 -0400 Subject: [PATCH 136/206] bump compat for ADTypes 0.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 143e2098..075ae92f 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ AdvancedVIReverseDiffExt = "ReverseDiff" AdvancedVIZygoteExt = "Zygote" [compat] -ADTypes = "0.1" +ADTypes = "0.1, 0.2" Accessors = "0.1" Bijectors = "0.12, 0.13" ChainRulesCore = "1.16" From a9b3f483f4ae3bd4ac2d569d21697c8a786c448c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:35:32 -0400 Subject: [PATCH 137/206] fix broken LaTeX in README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 07407fa9..86a57cb6 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right), \end{aligned} $$ + a `LogDensityProblem` can be implemented as ```julia using LogDensityProblems From 54826eb51c0a64bd7fd85b9363300c28e77381d7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:52:35 -0400 Subject: [PATCH 138/206] remove redundant use of PDMats in docs --- README.md | 9 ++++----- docs/Project.toml | 1 - docs/src/advi.md | 5 +---- docs/src/started.md | 6 ++---- 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 86a57cb6..695e9ed9 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*, using Bijectors function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -60,19 +60,18 @@ A simpler approach is to use `Turing`, where a `Turing.Model` can be automatical Let us instantiate a random normal-log-normal model. ```julia -using PDMats +using LinearAlgebra n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) ``` ADVI can be used as follows: ```julia -using LinearAlgebra using Optimisers using ADTypes, ForwardDiff import AdvancedVI as AVI @@ -81,7 +80,7 @@ b = Bijectors.bijector(model) b⁻¹ = inverse(b) # ADVI objective -objective = AVI.ADVI(model, 10; b=b⁻¹) +objective = AVI.ADVI(model, 10; invbij=b⁻¹) # Mean-field Gaussian variational family d = LogDensityProblems.dimension(model) diff --git a/docs/Project.toml b/docs/Project.toml index 182edd3e..568be1b6 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,7 +7,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" diff --git a/docs/src/advi.md b/docs/src/advi.md index 3ac90436..2773dda7 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -117,7 +117,6 @@ StickingTheLandingEntropy ```@setup stl using LogDensityProblems using SimpleUnPack -using PDMats using Bijectors using LinearAlgebra using Plots @@ -151,15 +150,13 @@ n_dims = 10 σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); q0 = AVI.VIMeanFieldGaussian(μ, L) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); - function Bijectors.bijector(model::NormalLogNormal) @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( diff --git a/docs/src/started.md b/docs/src/started.md index f3ae54b1..e8392fd7 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -51,14 +51,14 @@ end ``` Let's now instantiate the model ```@example advi -using PDMats +using LinearAlgebra n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. @@ -94,8 +94,6 @@ objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. ```@example advi -using LinearAlgebra - d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); From 1d1c8ffd320463b6bd9a552227270bb2837344b0 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 01:09:27 -0400 Subject: [PATCH 139/206] fix use `Cholesky` signature supported in 1.6 --- test/models/normallognormal.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index ec591f2c..e2b9e816 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,10 +32,10 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 Σ_y = L_y*L_y' |> Hermitian - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y))) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) Σ[1,1] = σ_x^2 From 7bac95b1dea4b15df7844602966ebf539ee43fe9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 01:34:21 -0400 Subject: [PATCH 140/206] revert custom variational families and docs --- docs/Project.toml | 17 - docs/make.jl | 22 - docs/src/advi.md | 227 -------- docs/src/index.md | 14 - docs/src/locscale.md | 85 --- docs/src/started.md | 132 ----- src/AdvancedVI.jl | 9 - src/distributions/location_scale.jl | 151 ----- test/Manifest.toml | 866 ++++++++++++++++++++++++++++ test/Project.toml | 2 + test/advi_locscale.jl | 30 +- test/distributions.jl | 96 --- test/optimize.jl | 4 +- test/runtests.jl | 5 +- 14 files changed, 883 insertions(+), 777 deletions(-) delete mode 100644 docs/Project.toml delete mode 100644 docs/make.jl delete mode 100644 docs/src/advi.md delete mode 100644 docs/src/index.md delete mode 100644 docs/src/locscale.md delete mode 100644 docs/src/started.md delete mode 100644 src/distributions/location_scale.jl create mode 100644 test/Manifest.toml delete mode 100644 test/distributions.jl diff --git a/docs/Project.toml b/docs/Project.toml deleted file mode 100644 index 568be1b6..00000000 --- a/docs/Project.toml +++ /dev/null @@ -1,17 +0,0 @@ -[deps] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" - -[compat] -ADTypes = "0.1.6" -Bijectors = "0.13.6" -Documenter = "0.26, 0.27" -LogDensityProblems = "2.1.1" diff --git a/docs/make.jl b/docs/make.jl deleted file mode 100644 index 5d371608..00000000 --- a/docs/make.jl +++ /dev/null @@ -1,22 +0,0 @@ - -using AdvancedVI -using Documenter - -DocMeta.setdocmeta!( - AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true -) - -makedocs(; - modules = [AdvancedVI], - sitename = "AdvancedVI.jl", - repo = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}", - format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["AdvancedVI" => "index.md", - "Getting Started" => "started.md", - "ELBO Maximization" => [ - "Automatic Differentiation VI" => "advi.md", - "Location Scale Family" => "locscale.md", - ]], -) - -deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) diff --git a/docs/src/advi.md b/docs/src/advi.md deleted file mode 100644 index 2773dda7..00000000 --- a/docs/src/advi.md +++ /dev/null @@ -1,227 +0,0 @@ - -# [Automatic Differentiation Variational Inference](@id advi) - -## Introduction - -The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``. -By maximizing ADVI objective, it is equivalent to solving the problem - -```math - \mathrm{minimize}_{\lambda \in \Lambda}\quad \mathrm{KL}\left(q_{\phi,\lambda}, \pi\right). -``` - -The key aspects of the ADVI objective are the followings: -1. The use of the reparameterization gradient estimator -2. Automatically match the support of the target posterior through "bijectors." - -Thanks to Item 2, the user is free to choose any unconstrained variational family, for which -bijectors will automatically match the potentially constrained support of the target. - -In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}`` -from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that -```math -z \sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad -z \stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} -``` -ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``. - -That is, - -```math -\begin{aligned} -\mathrm{ADVI}\left(\lambda\right) -&\triangleq -\mathbb{E}_{\eta \sim q_{\lambda}}\left[ - \log \pi\left( \phi^{-1}\left( \eta \right) \right) -\right] -+ \mathbb{H}\left(q_{\lambda}\right) -+ \log \lvert J_{\phi^{-1}}\left(\eta\right) \rvert \\ -&= -\mathbb{E}_{\eta \sim q_{\lambda}}\left[ - \log \pi\left( \phi^{-1}\left( \eta \right) \right) -\right] -+ -\mathbb{E}_{\eta \sim q_{\lambda}}\left[ - - \log q_{\lambda}\left( \eta \right) \lvert J_{\phi}\left(\eta\right) \rvert -\right] \\ -&= -\mathbb{E}_{z \sim q_{\phi,\lambda}}\left[ \log \pi\left(z\right) \right] -+ -\mathbb{H}\left(q_{\phi,\lambda}\right) -\end{aligned} -``` - -The idea of using the reparameterization gradient estimator for variational inference was first -coined by Titsias and Lázaro-Gredilla (2014). -Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by -Fjelde *et al.* (2017). - -## The `ADVI` Objective - -```@docs -ADVI -``` - -## The `StickingTheLanding` Control Variate - -The STL control variate was proposed by Roeder *et al.* (2017). -By slightly modifying the differentiation path, it implicitly forms a control variate of the form of -```math -\begin{aligned} - \mathrm{CV}_{\mathrm{STL}}\left(z\right) - &\triangleq - \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) \\ - &= - -\nabla_{\lambda} \mathbb{E}_{z \sim q_{\nu}} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) -\end{aligned} -``` -where ``\nu = \lambda`` is set to avoid differentiating through the density of ``q_{\lambda}``. -We can see that this vector-valued function has a mean of zero and is therefore a valid control variate. - -Adding this to the closed-form entropy ELBO estimator yields the STL estimator: -```math -\begin{aligned} - \widehat{\nabla \mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) - &\triangleq \mathbb{E}_{u \sim \varphi}\left[ - \nabla_{\lambda} \log \pi \left(z_{\lambda}\left(u\right)\right) - - - \nabla_{\lambda} \log q_{\nu} \left(z_{\lambda}\left(u\right)\right) - \right] - \\ - &= - \mathbb{E}\left[ \nabla_{\lambda} \log \pi\left(z_{\lambda}\left(u\right)\right) \right] - + - \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) - - - \mathrm{CV}_{\mathrm{STL}}\left(z\right) - \\ - &= - \widehat{\nabla \mathrm{ELBO}}\left(\lambda\right) - - - \mathrm{CV}_{\mathrm{STL}}\left(z\right), -\end{aligned} -``` -which has the same expectation as the original ADVI estimator, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. -The conditions for which the STL estimator results in lower variance is still an active subject for research. - -The main downside of the STL estimator is that it needs to evaluate and differentiate the log density of ``q_{\lambda}`` in every iteration. -Depending on the variational family, this might be computationally inefficient or even numerically unstable. -For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a back-substitution must be performed at every step, making the per-iteration complexity ``\mathcal{O}(d^3)`` and reducing numerical stability. - - -The STL control variate can be used by changing the entropy estimator using the following object: -```@docs -StickingTheLandingEntropy -``` - -```@setup stl -using LogDensityProblems -using SimpleUnPack -using Bijectors -using LinearAlgebra -using Plots - -using Optimisers -using ADTypes, ForwardDiff -import AdvancedVI as AVI - -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end - -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end - -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end - -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -n_dims = 10 -μ_x = randn() -σ_x = exp.(randn()) -μ_y = randn(n_dims) -σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); - -d = LogDensityProblems.dimension(model); -μ = randn(d); -L = Diagonal(ones(d)); -q0 = AVI.VIMeanFieldGaussian(μ, L) - -function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end -``` - -Let us come back to the example in [Getting Started](@ref getting_started), where a `LogDensityProblem` is given as `model`. -In this example, the true posterior is contained within the variational family. -This setting is known as "perfect variational family specification." -In this case, the STL estimator is able to converge exponentially fast to the true solution. - -Recall that the original ADVI objective with a closed-form entropy (CFE) is given as follows: -```@example stl -n_montecarlo = 1; -b = Bijectors.bijector(model); -b⁻¹ = inverse(b) - -cfe = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹) -``` -The STL estimator can instead be created as follows: -```@example stl -stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), invbij = b⁻¹); -``` - -```@setup stl -n_max_iter = 10^4 - -_, stats_cfe, _, _ = AVI.optimize( - cfe, - q0, - n_max_iter; - show_progress = false, - adbackend = AutoForwardDiff(), - optimizer = Optimisers.Adam(1e-3) -); - -_, stats_stl, _, _ = AVI.optimize( - stl, - q0, - n_max_iter; - show_progress = false, - adbackend = AutoForwardDiff(), - optimizer = Optimisers.Adam(1e-3) -); - -t = [stat.iteration for stat ∈ stats_cfe] -y_cfe = [stat.elbo for stat ∈ stats_cfe] -y_stl = [stat.elbo for stat ∈ stats_stl] -plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) -plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) -savefig("advi_stl_elbo.svg") -nothing -``` -![](advi_stl_elbo.svg) - -We can see that the noise of the STL estimator becomes smaller as VI converges. -However, the speed of convergence may not always be significantly different. - -## References -1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. -2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. -3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. -4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors. jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR. -5. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. - - diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index dea6d405..00000000 --- a/docs/src/index.md +++ /dev/null @@ -1,14 +0,0 @@ -```@meta -CurrentModule = AdvancedVI -``` - -# AdvancedVI - -## Introduction -[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. -VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. -`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. - -## Provided Algorithms -`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization: -- [Automatic Differentiation Variational Inference](@ref advi) diff --git a/docs/src/locscale.md b/docs/src/locscale.md deleted file mode 100644 index a5966f44..00000000 --- a/docs/src/locscale.md +++ /dev/null @@ -1,85 +0,0 @@ - -# [Location-Scale Variational Family](@id locscale) - -## Introduction -The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as -```math -z \sim q_{\lambda} \qquad\Leftrightarrow\qquad -z \stackrel{d}{=} C u + m;\quad u \sim \varphi -``` -where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. -``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. -The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. - -The probability density is given by -```math - q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)) -``` -and the entropy is given as -```math - \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|, -``` -where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. -Notice the ``\mathcal{H}(\varphi)`` does not depend on ``\log |C|``. -The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. - -## Constructors - -!!! note - For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. - Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. - -```@docs -VILocationScale -``` - -```@docs -VIFullRankGaussian -VIMeanFieldGaussian -``` - -## Gaussian Variational Families - -Gaussian variational family: -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); - -L = diagm(ones(2)) |> LowerTriangular; -q = VIFullRankGaussian(μ, L) - -L = ones(2) |> Diagonal; -q = VIMeanFieldGaussian(μ, L) -``` - -## Non-Gaussian Variational Families -Sudent-T Variational Family: - -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); -ν = 3; - -# Full-Rank -L = diagm(ones(2)) |> LowerTriangular; -q = VILocationScale(μ, L, TDist(ν)) - -# Mean-Field -L = ones(2) |> Diagonal; -q = VILocationScale(μ, L, TDist(ν)) -``` - -Multivariate Laplace family: -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); - -# Full-Rank -L = diagm(ones(2)) |> LowerTriangular; -q = VILocationScale(μ, L, Laplace()) - -# Mean-Field -L = ones(2) |> Diagonal; -q = VILocationScale(μ, L, Laplace()) -``` - diff --git a/docs/src/started.md b/docs/src/started.md deleted file mode 100644 index e8392fd7..00000000 --- a/docs/src/started.md +++ /dev/null @@ -1,132 +0,0 @@ - -# [Getting Started with `AdvancedVI`](@id getting_started) - -## General Usage -Each VI algorithm provides the followings: -1. Variational families supported by each VI algorithm. -2. A variational objective corresponding to the VI algorithm. -Note that each variational family is subject to its own constraints. -Thus, please refer to the documentation of the variational inference algorithm of interest. - -To use `AdvancedVI`, a user needs to select a `variational family`, `variational objective`, and feed them into `optimize`. - -```@docs -optimize -``` - -## `ADVI` Example -In this tutorial, we will work with a `normal-log-normal` model. -```math -\begin{aligned} -x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) -\end{aligned} -``` -ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. - -Using the `LogDensityProblems` interface, we the model can be defined as follows: -```@example advi -using LogDensityProblems -using SimpleUnPack - -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end - -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end - -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end - -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end -``` -Let's now instantiate the model -```@example advi -using LinearAlgebra - -n_dims = 10 -μ_x = randn() -σ_x = exp.(randn()) -μ_y = randn(n_dims) -σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); -``` - -Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. -Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation. -```@example advi -using Bijectors - -function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end - -b = Bijectors.bijector(model); -b⁻¹ = inverse(b) -``` - -Let's now load `AdvancedVI`. -Since ADVI relies on automatic differentiation (AD), hence the "AD" in "ADVI", we need to load an AD library, *before* loading `AdvancedVI`. -Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. -Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. -```@example advi -using Optimisers -using ADTypes, ForwardDiff -import AdvancedVI as AVI -``` -We now need to select 1. a variational objective, and 2. a variational family. -Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. -```@example advi -n_montecaro = 10; -objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) -``` -For the variational family, we will use the classic mean-field Gaussian family. -```@example advi -d = LogDensityProblems.dimension(model); -μ = randn(d); -L = Diagonal(ones(d)); -q = AVI.VIMeanFieldGaussian(μ, L) -``` -Passing `objective` and the initial variational approximation `q` to `optimize` performs inference. -```@example advi -n_max_iter = 10^4 -q, stats, _, _ = AVI.optimize( - objective, - q, - n_max_iter; - adbackend = AutoForwardDiff(), - optimizer = Optimisers.Adam(1e-3) -); -``` - -The selected inference procedure stores per-iteration statistics into `stats`. -For instance, the ELBO can be ploted as follows: -```@example advi -using Plots - -t = [stat.iteration for stat ∈ stats] -y = [stat.elbo for stat ∈ stats] -plot(t, y, label="ADVI", xlabel="Iteration", ylabel="ELBO") -savefig("advi_example_elbo.svg") -nothing -``` -![](advi_example_elbo.svg) - -Further information can be gathered by defining your own `callback!`. - -The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: -```@example advi -ELBO = objective(q; n_samples=10^4) -``` diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 7272303a..da8b05bb 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -67,15 +67,6 @@ export StickingTheLandingEntropy, MonteCarloEntropy -# Variational Families - -include("distributions/location_scale.jl") - -export - VILocationScale, - VIFullRankGaussian, - VIMeanFieldGaussian - # Optimization Routine function optimize end diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl deleted file mode 100644 index 91b6768a..00000000 --- a/src/distributions/location_scale.jl +++ /dev/null @@ -1,151 +0,0 @@ - -""" - VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution - -The location scale variational family broadly represents various variational -families using `location` and `scale` variational parameters. - -It generally represents any distribution for which the sampling path can be -represented as follows: -```julia - d = length(location) - u = rand(dist, d) - z = scale*u + location -``` -""" -struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution - location::L - scale ::S - dist ::D - - function VILocationScale(location::AbstractVector{<:Real}, - scale ::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, - dist ::ContinuousUnivariateDistribution) - # Restricting all the arguments to have the same types creates problems - # with dual-variable-based AD frameworks. - @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2)) - new{typeof(location), typeof(scale), typeof(dist)}(location, scale, dist) - end -end - -Functors.@functor VILocationScale (location, scale) - -# Specialization of `Optimisers.destructure` for mean-field location-scale families. -# These are necessary because we only want to extract the diagonal elements of -# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD -# is very inefficient. -# begin -struct RestructureMeanField{L, S<:Diagonal, D} - q::VILocationScale{L, S, D} -end - -function (re::RestructureMeanField)(flat::AbstractVector) - n_dims = div(length(flat), 2) - location = first(flat, n_dims) - scale = Diagonal(last(flat, n_dims)) - VILocationScale(location, scale, re.q.dist) -end - -function Optimisers.destructure( - q::VILocationScale{L, <:Diagonal, D} -) where {L, D} - @unpack location, scale, dist = q - flat = vcat(location, diag(scale)) - n_dims = length(location) - flat, RestructureMeanField(q) -end -# end - -Base.length(q::VILocationScale) = length(q.location) - -Base.size(q::VILocationScale) = size(q.location) - -Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D) - -function StatsBase.entropy(q::VILocationScale) - @unpack location, scale, dist = q - n_dims = length(location) - n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale)) -end - -function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) -end - -function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) -end - -function rand(q::VILocationScale) - @unpack location, scale, dist = q - n_dims = length(location) - scale*rand(dist, n_dims) + location -end - -function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) - @unpack location, scale, dist = q - n_dims = length(location) - scale*rand(rng, dist, n_dims, num_samples) .+ location -end - -# This specialization improves AD performance of the sampling path -function rand( - rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int -) where {L, D} - @unpack location, scale, dist = q - n_dims = length(location) - scale_diag = diag(scale) - scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location -end - -function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) - @unpack location, scale, dist = q - rand!(rng, dist, x) - x .= scale*x - return x += location -end - -function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) - @unpack location, scale, dist = q - rand!(rng, dist, x) - x[:] = scale*x - return x .+= location -end - -""" - VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}; check_args = true) - -This constructs a multivariate Gaussian distribution with a full rank covariance matrix. -""" -function VIFullRankGaussian( - μ::AbstractVector{T}, - L::AbstractTriangular{T}; - check_args::Bool = true -) where {T <: Real} - @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" - if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) - @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." - end - q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base) -end - -""" - VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}; check_args = true) - -This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. -""" -function VIMeanFieldGaussian( - μ::AbstractVector{T}, - L::Diagonal{T}; - check_args::Bool = true -) where {T <: Real} - @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor" - if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) - @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." - end - q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base) -end diff --git a/test/Manifest.toml b/test/Manifest.toml new file mode 100644 index 00000000..220b42bb --- /dev/null +++ b/test/Manifest.toml @@ -0,0 +1,866 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.9.2" +manifest_format = "2.0" +project_hash = "a6495d9f0ea044fd0a55c1c989f1adca1ad5c855" + +[[deps.ADTypes]] +git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a" +uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +version = "0.2.1" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.6.2" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Bijectors]] +deps = ["ArgCheck", "ChainRules", "ChainRulesCore", "ChangesOfVariables", "Compat", "Distributions", "Functors", "InverseFunctions", "IrrationalConstants", "LinearAlgebra", "LogExpFunctions", "MappedArrays", "Random", "Reexport", "Requires", "Roots", "SparseArrays", "Statistics"] +git-tree-sha1 = "af192c7c235264bdc6f67321fd1c57be0dd7ffb5" +uuid = "76274a88-744f-5084-9051-94815aaf08c4" +version = "0.13.6" + + [deps.Bijectors.extensions] + BijectorsDistributionsADExt = "DistributionsAD" + BijectorsForwardDiffExt = "ForwardDiff" + BijectorsLazyArraysExt = "LazyArrays" + BijectorsReverseDiffExt = "ReverseDiff" + BijectorsTrackerExt = "Tracker" + BijectorsZygoteExt = "Zygote" + + [deps.Bijectors.weakdeps] + DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.CEnum]] +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.2" + +[[deps.Calculus]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.5.1" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] +git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.53.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.16.0" + +[[deps.ChangesOfVariables]] +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.8" +weakdeps = ["InverseFunctions"] + + [deps.ChangesOfVariables.extensions] + ChangesOfVariablesInverseFunctionsExt = "InverseFunctions" + +[[deps.CommonSolve]] +git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" +uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +version = "0.2.4" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Comonicon]] +deps = ["Configurations", "ExproniconLite", "Libdl", "Logging", "Markdown", "OrderedCollections", "PackageCompiler", "Pkg", "Scratch", "TOML", "UUIDs"] +git-tree-sha1 = "9c360961f23e2fae4c6549bbba58a6f39c9e145c" +uuid = "863f3e99-da2a-4334-8734-de3dacbe5542" +version = "1.0.5" + +[[deps.Compat]] +deps = ["UUIDs"] +git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.9.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.5+0" + +[[deps.Configurations]] +deps = ["ExproniconLite", "OrderedCollections", "TOML"] +git-tree-sha1 = "434f446dbf89d08350e83bf57c0fc86f5d3ffd4e" +uuid = "5218b696-f38b-4ac9-8b61-a12ec717816d" +version = "0.17.5" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.3" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.DataAPI]] +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.15.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.15" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.Distributions]] +deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.100" + + [deps.Distributions.extensions] + DistributionsChainRulesCoreExt = "ChainRulesCore" + DistributionsDensityInterfaceExt = "DensityInterface" + + [deps.Distributions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" + +[[deps.DistributionsAD]] +deps = ["Adapt", "ChainRules", "ChainRulesCore", "Compat", "Distributions", "FillArrays", "LinearAlgebra", "PDMats", "Random", "Requires", "SpecialFunctions", "StaticArrays", "StatsFuns", "ZygoteRules"] +git-tree-sha1 = "975de103eb2175cf54bf14b15ded2c68625eabdf" +uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +version = "0.6.52" + + [deps.DistributionsAD.extensions] + DistributionsADForwardDiffExt = "ForwardDiff" + DistributionsADLazyArraysExt = "LazyArrays" + DistributionsADReverseDiffExt = "ReverseDiff" + DistributionsADTrackerExt = "Tracker" + + [deps.DistributionsAD.weakdeps] + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.DualNumbers]] +deps = ["Calculus", "NaNMath", "SpecialFunctions"] +git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" +uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +version = "0.6.8" + +[[deps.Enzyme]] +deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random"] +git-tree-sha1 = "1f85bc8a9da6118abb95d134efc68cf4a6957341" +uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" +version = "0.11.7" + +[[deps.EnzymeCore]] +deps = ["Adapt"] +git-tree-sha1 = "643995502bdfff08bf080212c92430510be01ad5" +uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" +version = "0.5.2" + +[[deps.Enzyme_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "ffa4926cc857bcc5c256825bd7273a6ac989eb34" +uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef" +version = "0.0.80+0" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.ExproniconLite]] +deps = ["Pkg", "TOML"] +git-tree-sha1 = "d80b5d5990071086edf5de9018c6c69c83937004" +uuid = "55351af7-c7e9-48d6-89ff-24e801d99491" +version = "0.10.3" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "048dd3d82558759476cff9cff999219216932a08" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.6.0" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.5" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.8.1" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.5" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.21.4" + +[[deps.HypergeometricFunctions]] +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.23" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.10" + +[[deps.InlineTest]] +deps = ["Test"] +git-tree-sha1 = "daf0743879904f0ad645ca6594e1479685f158a2" +uuid = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6" +version = "0.2.0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.12" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.8" +weakdeps = ["EnzymeCore"] + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "6.1.0" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.23+0" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogDensityProblems]] +deps = ["ArgCheck", "DocStringExtensions", "Random"] +git-tree-sha1 = "f9a11237204bc137617194d79d813069838fcf61" +uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +version = "2.1.1" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.26" +weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"] + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.11" + +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+0" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.1.0" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.10.11" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.4" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.ObjectFile]] +deps = ["Reexport", "StructIO"] +git-tree-sha1 = "69607899b46e1f8ead70396bc51a4c361478d8f6" +uuid = "d8793406-e978-5875-9003-1fc021f44a92" +version = "0.4.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.21+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.20" + +[[deps.OrderedCollections]] +git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.2" + +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.17" + +[[deps.PackageCompiler]] +deps = ["Artifacts", "LazyArtifacts", "Libdl", "Pkg", "Printf", "RelocatableFolders", "TOML", "UUIDs"] +git-tree-sha1 = "1a6a868eb755e8ea9ecd000aa6ad175def0cc85b" +uuid = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d" +version = "2.1.7" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.9.2" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.8.2" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.6.1" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.ReTest]] +deps = ["Distributed", "InlineTest", "Printf", "Random", "Sockets", "Test"] +git-tree-sha1 = "dd8f6587c0abac44bcec2e42f0aeddb73550c0ec" +uuid = "e0db7c4e-2690-44b9-bad6-7687da720f89" +version = "0.3.2" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.RelocatableFolders]] +deps = ["SHA", "Scratch"] +git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" +uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" +version = "1.0.0" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.ReverseDiff]] +deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] +git-tree-sha1 = "d1235bdd57a93bd7504225b792b867e9a7df38d5" +uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +version = "1.15.1" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.1" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.4.0+0" + +[[deps.Roots]] +deps = ["ChainRulesCore", "CommonSolve", "Printf", "Setfield"] +git-tree-sha1 = "ff42754a57bb0d6dcfe302fd0d4272853190421f" +uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" +version = "2.0.19" + + [deps.Roots.extensions] + RootsForwardDiffExt = "ForwardDiff" + RootsIntervalRootFindingExt = "IntervalRootFinding" + RootsSymPyExt = "SymPy" + + [deps.Roots.weakdeps] + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807" + SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.SimpleUnPack]] +git-tree-sha1 = "58e6353e72cde29b90a69527e56df1b5c3d8c437" +uuid = "ce78b400-467f-4804-87d8-8f486da07d0a" +version = "1.1.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.1.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.3.1" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore"] +git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.6.2" +weakdeps = ["Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.6.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.0" + +[[deps.StatsFuns]] +deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.3.0" +weakdeps = ["ChainRulesCore", "InverseFunctions"] + + [deps.StatsFuns.extensions] + StatsFunsChainRulesCoreExt = "ChainRulesCore" + StatsFunsInverseFunctionsExt = "InverseFunctions" + +[[deps.StructArrays]] +deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] +git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.15" + +[[deps.StructIO]] +deps = ["Test"] +git-tree-sha1 = "010dc73c7146869c042b49adcdb6bf528c12e859" +uuid = "53d494c1-5632-5724-8f4c-31dff12d585f" +version = "0.3.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+6" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.10.1" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.23" + +[[deps.Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.2.26" +weakdeps = ["PDMats"] + + [deps.Tracker.extensions] + TrackerPDMatsExt = "PDMats" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.3" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+0" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.63" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.3" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/test/Project.toml b/test/Project.toml index 663d671d..5ce8fcd8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,9 +3,11 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index d5250ce8..f2ce94a5 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -9,7 +9,6 @@ using ReTest realtype ∈ [Float64], # Currently only tested against Float64 (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, - :NormalLogNormalFullRank => normallognormal_fullrank, ), (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹), @@ -17,7 +16,7 @@ using ReTest ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), + # :ReverseDiff => AutoReverseDiff(), # :Zygote => AutoZygote(), # :Enzyme => AutoEnzyme(), ) @@ -32,19 +31,10 @@ using ReTest b = Bijectors.bijector(model) b⁻¹ = inverse(b) + μ₀ = zeros(realtype, n_dims) + L₀ = Diagonal(ones(realtype, n_dims)) - μ₀ = zeros(realtype, n_dims) - L₀ = if is_meanfield - FillArrays.Eye(n_dims) |> Diagonal - else - FillArrays.Eye(n_dims) |> Matrix |> LowerTriangular - end - - q₀ = if is_meanfield - VIMeanFieldGaussian(μ₀, L₀) - else - VIFullRankGaussian(μ₀, L₀) - end + q₀ = TuringDiagMvNormal(μ₀, diag(L₀)) obj = objective(model, b⁻¹, 10) @@ -58,8 +48,8 @@ using ReTest adbackend = adbackend, ) - μ = q.location - L = q.scale + μ = mean(q) + L = sqrt(cov(q)) Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) @test Δλ ≤ Δλ₀/T^(1/4) @@ -76,8 +66,8 @@ using ReTest rng = rng, adbackend = adbackend, ) - μ = q.location - L = q.scale + μ = mean(q) + L = sqrt(cov(q)) rng_repl = Philox4x(UInt64, seed, 8) q, stats, _, _ = optimize( @@ -87,8 +77,8 @@ using ReTest rng = rng_repl, adbackend = adbackend, ) - μ_repl = q.location - L_repl = q.scale + μ_repl = mean(q) + L_repl = sqrt(cov(q)) @test μ == μ_repl @test L == L_repl end diff --git a/test/distributions.jl b/test/distributions.jl deleted file mode 100644 index 175cc96b..00000000 --- a/test/distributions.jl +++ /dev/null @@ -1,96 +0,0 @@ - -using ReTest -using Distributions: _logpdf - -@testset "distributions" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for - basedist = [:gaussian], - covtype = [:meanfield, :fullrank], - realtype = [Float32, Float64] - - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - n_dims = 10 - n_montecarlo = 1000_000 - - μ = randn(rng, realtype, n_dims) - L = if covtype == :fullrank - tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular - else - Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) - end - Σ = L*L' - - q = if covtype == :fullrank && basedist == :gaussian - VIFullRankGaussian(μ, L) - elseif covtype == :meanfield && basedist == :gaussian - VIMeanFieldGaussian(μ, L) - end - q_true = if basedist == :gaussian - MvNormal(μ, Σ) - end - - @testset "logpdf" begin - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - z = rand(rng, q) - @test eltype(z) == realtype - @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) - @test _logpdf(q, z) ≈ _logpdf(q_true, z) rtol=realtype(1e-2) - @test eltype(logpdf(q, z)) == realtype - @test eltype(_logpdf(q, z)) == realtype - end - - @testset "entropy" begin - @test eltype(entropy(q)) == realtype - @test entropy(q) ≈ entropy(q_true) - end - - @testset "sampling" begin - @testset "rand" begin - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - z_samples = mapreduce(x -> rand(rng, q), hcat, 1:n_montecarlo) - @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) - end - - @testset "rand batch" begin - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - z_samples = rand(rng, q, n_montecarlo) - @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) - end - - @testset "rand!" begin - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - z_samples = Array{realtype}(undef, n_dims, n_montecarlo) - rand!(rng, q, z_samples) - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) - end - end - end - - @testset "Diagonal destructure" for - n_dims = 10 - μ = zeros(n_dims) - L = ones(n_dims) - q = VIMeanFieldGaussian(μ, L |> Diagonal) - λ, re = Optimisers.destructure(q) - - @test length(λ) == 2*n_dims - @test q == re(λ) - end -end diff --git a/test/optimize.jl b/test/optimize.jl index 2369432c..56ca63c0 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -12,9 +12,7 @@ using ReTest # Global Test Configurations b⁻¹ = Bijectors.bijector(model) |> inverse - μ₀ = zeros(Float64, n_dims) - L₀ = ones(Float64, n_dims) |> Diagonal - q₀ = VIMeanFieldGaussian(μ₀, L₀) + q₀ = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) obj = ADVI(model, 10; invbij=b⁻¹) adbackend = AutoForwardDiff() diff --git a/test/runtests.jl b/test/runtests.jl index 127503be..fd68ed79 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,10 @@ using SimpleUnPack: @unpack using FillArrays using PDMats +using Functors +using DistributionsAD +@functor TuringDiagMvNormal + using Bijectors using LogDensityProblems using Optimisers @@ -33,7 +37,6 @@ include("models/normallognormal.jl") # Tests include("ad.jl") -include("distributions.jl") include("advi_locscale.jl") include("optimize.jl") From d2ae29fffcbfacad59268b1c6835b43858e138db Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 02:21:56 -0400 Subject: [PATCH 141/206] remove doc action for now --- .github/workflows/CI.yml | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 26f6876f..7ba573a1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -61,30 +61,3 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info - docs: - name: Documentation - runs-on: ubuntu-latest - permissions: - contents: write - statuses: write - steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 - with: - version: '1' - - name: Configure doc environment - run: | - julia --project=docs/ -e ' - using Pkg - Pkg.develop(PackageSpec(path=pwd())) - Pkg.instantiate()' - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-docdeploy@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - run: | - julia --project=docs -e ' - using Documenter: DocMeta, doctest - using AdvancedVI - DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true) - doctest(AdvancedVI)' From fb84e3d3aa0e383c94fe88e0a8b33c845f916cd7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 02:27:37 -0400 Subject: [PATCH 142/206] revert README for now --- README.md | 301 ++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 222 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 695e9ed9..f0bf6cc1 100644 --- a/README.md +++ b/README.md @@ -1,108 +1,251 @@ - # AdvancedVI.jl -[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. -VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. -`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. -The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. -For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) by simply converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`. +A library for variational Bayesian inference in Julia. + +At the time of writing (05/02/2020), implementations of the variational inference (VI) interface and some algorithms are implemented in [Turing.jl](https://github.com/TuringLang/Turing.jl). The idea is to soon separate the VI functionality in Turing.jl out and into this package. + +The purpose of this package will then be to provide a common interface together with implementations of standard algorithms and utilities with the goal of ease of use and the ability for other packages, e.g. Turing.jl, to write a light wrapper around AdvancedVI.jl for integration. + +As an example, in Turing.jl we support automatic differentiation variational inference (ADVI) but really the only piece of code tied into the Turing.jl is the conversion of a `Turing.Model` to a `logjoint(z)` function which computes `z ↦ log p(x, z)`, with `x` denoting the observations embedded in the `Turing.Model`. As long as this `logjoint(z)` method is compatible with some AD framework, e.g. `ForwardDiff.jl` or `Zygote.jl`, this is all we need from Turing.jl to be able to perform ADVI! + +## [WIP] Interface +- `vi`: the main interface to the functionality in this package + - `vi(model, alg)`: only used when `alg` has a default variational posterior which it will provide. + - `vi(model, alg, q::VariationalPosterior, θ)`: `q` represents the family of variational distributions and `θ` is the initial parameters "indexing" the starting distribution. This assumes that there exists an implementation `Variational.update(q, θ)` which returns the variational posterior corresponding to parameters `θ`. + - `vi(model, alg, getq::Function, θ)`: here `getq(θ)` is a function returning a `VariationalPosterior` corresponding to `θ`. +- `optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())` +- `grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)` + - Different combinations of variational objectives (`vo`), VI methods (`alg`), and variational posteriors (`q`) might use different gradient estimators. `grad!` allows us to specify these different behaviors. ## Examples +### Variational Inference +A very simple generative model is the following -`AdvancedVI` expects a `LogDensityProblem`. -For example, for the normal-log-normal model: + μ ~ 𝒩(0, 1) + xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n -$$ -\begin{aligned} -x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right), -\end{aligned} -$$ +where μ and xᵢ are some ℝᵈ vectors and 𝒩 denotes a d-dimensional multivariate Normal distribution. -a `LogDensityProblem` can be implemented as +Given a set of `n` observations `[x₁, …, xₙ]` we're interested in finding the distribution `p(μ∣x₁, …, xₙ)` over the mean `μ`. We can obtain (an approximation to) this distribution that using AdvancedVI.jl! + +First we generate some observations and set up the problem: ```julia -using LogDensityProblems +julia> using Distributions -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end +julia> d = 2; n = 100; -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end +julia> observations = randn((d, n)); # 100 observations from 2D 𝒩(0, 1) -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end +julia> # Define generative model + # μ ~ 𝒩(0, 1) + # xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n + prior(μ) = logpdf(MvNormal(ones(d)), μ) +prior (generic function with 1 method) -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end -``` +julia> likelihood(x, μ) = sum(logpdf(MvNormal(μ, ones(d)), x)) +likelihood (generic function with 1 method) + +julia> logπ(μ) = likelihood(observations, μ) + prior(μ) +logπ (generic function with 1 method) -Since the support of `x` is constrained to be $$\mathbb{R}_+$$, and inference is best done in the unconstrained space $$\mathbb{R}_+$$, we need to use a *bijector* to match support. -This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*, 2015). +julia> logπ(randn(2)) # <= just checking that it works +-311.74132761437653 +``` +Now there are mainly two different ways of specifying the approximate posterior (and its family). The first is by providing a mapping from distribution parameters to the distribution `θ ↦ q(⋅∣θ)`: ```julia -using Bijectors +julia> using DistributionsAD, AdvancedVI -function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end +julia> # Using a function z ↦ q(⋅∣z) + getq(θ) = TuringDiagMvNormal(θ[1:d], exp.(θ[d + 1:4])) +getq (generic function with 1 method) ``` +Then we make the choice of algorithm, a subtype of `VariationalInference`, +```julia +julia> # Perform VI + advi = ADVI(10, 10_000) +ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 10000) +``` +And finally we can perform VI! The usual inferface is to call `vi` which behind the scenes takes care of the optimization and returns the resulting variational posterior: +```julia +julia> q = vi(logπ, advi, getq, randn(4)) +[ADVI] Optimizing...100% Time: 0:00:01 +TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[0.16282745378074515, 0.15789310089462574], σ=[0.09519377533754399, 0.09273176907111745]) +``` +Let's have a look at the resulting ELBO: +```julia +julia> AdvancedVI.elbo(advi, q, logπ, 1000) +-287.7866366886285 +``` +Unfortunately, the *final* value of the ELBO is not always a very good diagnostic, though the ELBO is an important metric to keep an eye on during training since an *increase* in the ELBO means we're going in the right direction. Luckily, this is such a simple problem that we can indeed obtain a closed form solution! Because we're lazy (at least I am), we'll let [ConjugatePriors.jl](https://github.com/JuliaStats/ConjugatePriors.jl) do this for us: +```julia +julia> # True posterior + using ConjugatePriors + +julia> pri = MvNormal(zeros(2), ones(2)); -A simpler approach is to use `Turing`, where a `Turing.Model` can be automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated. +julia> true_posterior = posterior((pri, pri.Σ), MvNormal, observations) +DiagNormal( +dim: 2 +μ: [0.1746546592601148, 0.16457110079543008] +Σ: [0.009900990099009901 0.0; 0.0 0.009900990099009901] +) +``` +Comparing to our variational approximation, this looks pretty good! Worth noting that in this particular case the variational posterior seems to overestimate the variance. -Let us instantiate a random normal-log-normal model. +To conclude, let's make a somewhat pretty picture: ```julia -using LinearAlgebra - -n_dims = 10 -μ_x = randn() -σ_x = exp.(randn()) -μ_y = randn(n_dims) -σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) +julia> using Plots + +julia> p_samples = rand(true_posterior, 10_000); q_samples = rand(q, 10_000); + +julia> p1 = histogram(p_samples[1, :], label="p"); histogram!(q_samples[1, :], alpha=0.7, label="q") + +julia> title!(raw"$\mu_1$") + +julia> p2 = histogram(p_samples[2, :], label="p"); histogram!(q_samples[2, :], alpha=0.7, label="q") + +julia> title!(raw"$\mu_2$") + +julia> plot(p1, p2) ``` +![Histogram](hist.png?raw=true) + +### Simple example: using Advanced.jl to directly minimize the KL-divergence between two distributions `p(z)` and `q(z)` +In VI we aim to approximate the true posterior `p(z ∣ x)` by some approximate variational posterior `q(z)` by maximizing the ELBO: + + ELBO(q) = 𝔼_q[log p(x, z) - log q(z)] + +Observe that we can express the ELBO as the negative KL-divergence between `p(x, ⋅)` and `q(⋅)`: + + ELBO(q) = - 𝔼_q[log (q(z) / p(x, z))] + = - KL(q(⋅) || p(x, ⋅)) + +So if we apply VI to something that isn't an actual posterior, i.e. there's no data involved and we write `p(z ∣ x) = p(z)`, we're really just minimizing the KL-divergence between the distributions. + +Therefore, we can try out `AdvancedVI.jl` real quick by applying using the interface to minimize the KL-divergence between two distributions: -ADVI can be used as follows: ```julia -using Optimisers -using ADTypes, ForwardDiff -import AdvancedVI as AVI - -b = Bijectors.bijector(model) -b⁻¹ = inverse(b) - -# ADVI objective -objective = AVI.ADVI(model, 10; invbij=b⁻¹) - -# Mean-field Gaussian variational family -d = LogDensityProblems.dimension(model) -μ = randn(d) -L = Diagonal(ones(d)) -q = AVI.VIMeanFieldGaussian(μ, L) - -# Run inference -n_max_iter = 10^4 -q, stats, _ = AVI.optimize( - objective, - q, - n_max_iter; - adbackend = ADTypes.AutoForwardDiff(), - optimizer = Optimisers.Adam(1e-3) +julia> using Distributions, DistributionsAD, AdvancedVI + +julia> # Target distribution + p = MvNormal(ones(2)) +ZeroMeanDiagNormal( +dim: 2 +μ: [0.0, 0.0] +Σ: [1.0 0.0; 0.0 1.0] ) -# Evaluate final ELBO with 10^3 Monte Carlo samples -objective(q; n_samples=10^3) +julia> logπ(z) = logpdf(p, z) +logπ (generic function with 1 method) + +julia> # Make a choice of VI algorithm + advi = ADVI(10, 1000) +ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 1000) +``` +Now there are two different ways of specifying the approximate posterior (and its family); the first is by providing a mapping from parameters to distribution `θ ↦ q(⋅∣θ)`: +```julia +julia> # Using a function z ↦ q(⋅∣z) + getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) +getq (generic function with 1 method) + +julia> # Perform VI + q = vi(logπ, advi, getq, randn(4)) +┌ Info: [ADVI] Should only be seen once: optimizer created for θ +└ objectid(θ) = 0x5ddb564423896704 +[ADVI] Optimizing...100% Time: 0:00:01 +TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[-0.012691337868985757, -0.0004442434543332919], σ=[1.0334797673569802, 0.9957355128767893]) +``` +Or we can check the ELBO (which in this case since, as mentioned, doesn't involve data, is the negative KL-divergence): +```julia +julia> AdvancedVI.elbo(advi, q, logπ, 1000) # empirical estimate +0.08031049170093245 +``` +It's worth noting that the actual value of the ELBO doesn't really tell us too much about the quality of fit. In this particular case, because we're *directly* minimizing the KL-divergence, we can only say something useful if we reach 0, in which case we have obtained the true distribution. + +Let's just quickly check the mean-squared error between the `log p(z)` and `log q(z)` for a random set of samples from the target `p`: +```julia +julia> zs = rand(p, 100); + +julia> mean(abs2, logpdf(q, zs) - logpdf(p, zs)) +0.0014889109427524852 +``` +That doesn't look too bad! + +## Implementing your own training loop +Sometimes it might be convenient to roll your own training loop rather than using `vi(...)`. Here's some psuedo-code for how one would do that when used together with Turing.jl: + +```julia +using Turing, AdvancedVI, DiffResults +using Turing: Variational + +using ProgressMeter + +# Assuming you have an instance of a Turing model (`model`) + +# 1. Create log-joint needed for ELBO evaluation +logπ = Variational.make_logjoint(model) + +# 2. Define objective +variational_objective = Variational.ELBO() + +# 3. Optimizer +optimizer = Variational.DecayedADAGrad() + +# 4. VI-algorithm +alg = ADVI(10, 1000) + +# 5. Variational distribution +function getq(θ) + # ... +end + +# 6. [OPTIONAL] Implement convergence criterion +function hasconverged(args...) + # ... +end + +# 7. [OPTIONAL] Implement a callback for tracking stats +function callback(args...) + # ... +end + +# 8. Train +converged = false +step = 1 + +prog = ProgressMeter.Progress(num_steps, 1) + +diff_results = DiffResults.GradientResult(θ_init) + +while (step ≤ num_steps) && !converged + # 1. Compute gradient and objective value; results are stored in `diff_results` + AdvancedVI.grad!(variational_objective, alg, getq, model, diff_results) + + # 2. Extract gradient from `diff_result` + ∇ = DiffResults.gradient(diff_result) + + # 3. Apply optimizer, e.g. multiplying by step-size + Δ = apply!(optimizer, θ, ∇) + + # 4. Update parameters + @. θ = θ - Δ + + # 5. Do whatever analysis you want + callback(args...) + + # 6. Update + converged = hasconverged(...) # or something user-defined + step += 1 + + ProgressMeter.next!(prog) +end ``` ## References +- Jordan, Michael I., Zoubin Ghahramani, Tommi S. Jaakkola, and Lawrence K. Saul. "An introduction to variational methods for graphical models." Machine learning 37, no. 2 (1999): 183-233. +- Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. "Variational inference: A review for statisticians." Journal of the American statistical Association 112, no. 518 (2017): 859-877. - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015. +- Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882. +- Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003. + From 0575b23ee90f9677a2d4db9482d9fcb4feeea846 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 25 Aug 2023 19:29:56 +0100 Subject: [PATCH 143/206] refactor remove redundant `rng` argument to `ADVI`, improve docs --- src/objectives/elbo/advi.jl | 25 +++++++++++++++++++++---- src/objectives/elbo/entropy.jl | 3 +++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index f9a61d81..ef0ac50d 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -3,6 +3,20 @@ ADVI(prob, n_samples; kwargs...) Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. +This computes the evidence lower-bound (ELBO) through the ADVI formulation: +```math +\\begin{aligned} +\\mathrm{ADVI}\\left(\\lambda\\right) +&\\triangleq +\\mathbb{E}_{\\eta \\sim q_{\\lambda}}\\left[ + \\log \\pi\\left( \\phi^{-1}\\left( \\eta \\right) \\right) + + + \\log \\lvert J_{\\phi^{-1}}\\left(\\eta\\right) \\rvert +\\right] ++ \\mathbb{H}\\left(q_{\\lambda}\\right), +\\end{aligned} +``` +where ``\\phi^{-1}`` is an "inverse bijector." # Arguments - `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface. @@ -11,13 +25,17 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) # Keyword Arguments - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) - `cv`: A control variate. -- `invbij`: A bijective mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) +- `invbij`: An inverse bijective mapping that matches the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) # Requirements - ``q_{\\lambda}`` implements `rand`. - `logdensity(prob)` must be differentiable by the selected AD backend. Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. + +# References +* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. +* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. """ struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective prob ::P @@ -49,7 +67,6 @@ Base.show(io::IO, advi::ADVI) = init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing function (advi::ADVI)( - rng::AbstractRNG, q_η::ContinuousMultivariateDistribution, ηs ::AbstractMatrix ) @@ -81,7 +98,7 @@ function (advi::ADVI)( n_samples::Int = advi.n_samples ) ηs = rand(rng, q_η, n_samples) - advi(rng, q_η, ηs) + advi(q_η, ηs) end function estimate_gradient( @@ -96,7 +113,7 @@ function estimate_gradient( f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) - -advi(rng, q_η, ηs) + -advi(q_η, ηs) end value_and_gradient!(adbackend, f, λ, out) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 97ccda29..e6212c46 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -21,6 +21,9 @@ The "sticking the landing" entropy estimator. # Requirements - `q` implements `logpdf`. - `logpdf(q, η)` must be differentiable by the selected AD framework. + +# References +* Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. """ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end From ecc52428b8ed79f18525fbc14cbf7d6632f9cac9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 25 Aug 2023 19:30:23 +0100 Subject: [PATCH 144/206] fix wrong whitespace in tests --- test/advi_locscale.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index f2ce94a5..93ece412 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -32,9 +32,9 @@ using ReTest b = Bijectors.bijector(model) b⁻¹ = inverse(b) μ₀ = zeros(realtype, n_dims) - L₀ = Diagonal(ones(realtype, n_dims)) + L₀ = Diagonal(ones(realtype, n_dims)) - q₀ = TuringDiagMvNormal(μ₀, diag(L₀)) + q₀ = TuringDiagMvNormal(μ₀, diag(L₀)) obj = objective(model, b⁻¹, 10) @@ -48,8 +48,8 @@ using ReTest adbackend = adbackend, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q) + L = sqrt(cov(q)) Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) @test Δλ ≤ Δλ₀/T^(1/4) @@ -66,8 +66,8 @@ using ReTest rng = rng, adbackend = adbackend, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q) + L = sqrt(cov(q)) rng_repl = Philox4x(UInt64, seed, 8) q, stats, _, _ = optimize( @@ -77,8 +77,8 @@ using ReTest rng = rng_repl, adbackend = adbackend, ) - μ_repl = mean(q) - L_repl = sqrt(cov(q)) + μ_repl = mean(q) + L_repl = sqrt(cov(q)) @test μ == μ_repl @test L == L_repl end From 1cff3df3af793b684934107521031a55df222419 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 25 Aug 2023 19:56:56 +0100 Subject: [PATCH 145/206] refactor `estimate_gradient` to `estimate_gradient!`, add docs --- src/AdvancedVI.jl | 55 +++++++++++++++++++++++++++++++++---- src/objectives/elbo/advi.jl | 10 ++----- src/optimize.jl | 2 +- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index da8b05bb..609266b4 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -40,18 +40,63 @@ using StatsBase: entropy out::DiffResults.MutableDiffResult ) -Compute the value and gradient of a function `f` at `θ` using the automatic -differentiation backend `ad`. The result is stored in `out`. -The function `f` must return a scalar value. The gradient is stored in `out` as a -vector of the same length as `θ`. +Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`. +The result is stored in `out`. +The function `f` must return a scalar value. The gradient is stored in `out` as a vector of the same length as `θ`. """ function value_and_gradient! end # estimators +""" + abstract type AbstractVariationalObjective end + +An VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`. +Furthermore, it should implement the functions `init`, `estimate_gradient`. +""" abstract type AbstractVariationalObjective end +""" + init( + rng::AbstractRNG, + obj::AbstractVariationalObjective, + λ::AbstractVector, + restructure + ) + +Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. +This is relevant only if `obj` is stateful. + +!!! warning + This is an internal function. Thus, the signature is subject to change without + notice. +""" function init end -function estimate_gradient end + +""" + estimate_gradient!( + rng ::AbstractRNG, + adbackend ::AbstractADType, + obj ::AbstractVariationalObjective, + obj_state, + λ ::AbstractVector, + restructure, + out ::DiffResults.MutableDiffResult + ) + +Estimate (possibly stochastic) gradients of the objective `obj` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`. +The estimated objective value and gradient are then stored in `out`. +If the objective is stateful, `obj_state` is its previous state. + +# Returns +- `out`: The `MutableDiffResult` containing the objective value and gradient estimates. +- `obj_state`: The updated state of the objective estimator. +- `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`) + +!!! warning + This is an internal function. Thus, the signature is subject to change without + notice. +""" +function estimate_gradient! end # ADVI-specific interfaces abstract type AbstractEntropyEstimator end diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index ef0ac50d..0e373f9a 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -85,12 +85,8 @@ end n_samples::Int = advi.n_samples ) -Evaluate the ELBO using the ADVI formulation. - -# Arguments -- `q_η`: Variational approximation before applying a bijector (unconstrained support). -- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. - +Estimate the ELBO of the variational approximation `q_η` using the ADVI +formulation using `n_samples` number of Monte Carlo samples. """ function (advi::ADVI)( q_η ::ContinuousMultivariateDistribution; @@ -101,7 +97,7 @@ function (advi::ADVI)( advi(q_η, ηs) end -function estimate_gradient( +function estimate_gradient!( rng ::AbstractRNG, adbackend ::AbstractADType, advi ::ADVI, diff --git a/src/optimize.jl b/src/optimize.jl index 54e7ace0..f21e757a 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -77,7 +77,7 @@ function optimize( for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, obj_state, stat′ = estimate_gradient( + grad_buf, obj_state, stat′ = estimate_gradient!( rng, adbackend, objective, obj_state, λ, restructure, grad_buf) stat = merge(stat, stat′) From 54acd8af483af503d17997d91cb093ed420c0140 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 25 Aug 2023 20:05:07 +0100 Subject: [PATCH 146/206] refactor add default `init` impl, update docs --- src/AdvancedVI.jl | 17 ++++++++++++----- src/objectives/elbo/advi.jl | 2 -- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 609266b4..db433a67 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -42,7 +42,8 @@ using StatsBase: entropy Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`. The result is stored in `out`. -The function `f` must return a scalar value. The gradient is stored in `out` as a vector of the same length as `θ`. +The function `f` must return a scalar value. +The gradient is stored in `out` as a vector of the same length as `θ`. """ function value_and_gradient! end @@ -51,7 +52,8 @@ function value_and_gradient! end abstract type AbstractVariationalObjective end An VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`. -Furthermore, it should implement the functions `init`, `estimate_gradient`. +Furthermore, it should implement the functions `estimate_gradient`. +If the estimator is stateful, it can implement `init` to initialize the state. """ abstract type AbstractVariationalObjective end @@ -64,13 +66,18 @@ abstract type AbstractVariationalObjective end ) Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. -This is relevant only if `obj` is stateful. +This function needs to be implemented only if `obj` is stateful. !!! warning This is an internal function. Thus, the signature is subject to change without notice. """ -function init end +init( + rng::AbstractRNG, + obj::AbstractVariationalObjective, + λ::AbstractVector, + restructure +) = nothing """ estimate_gradient!( @@ -85,7 +92,7 @@ function init end Estimate (possibly stochastic) gradients of the objective `obj` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`. The estimated objective value and gradient are then stored in `out`. -If the objective is stateful, `obj_state` is its previous state. +If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`. # Returns - `out`: The `MutableDiffResult` containing the objective value and gradient estimates. diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 0e373f9a..5a3ce96e 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -64,8 +64,6 @@ end Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") -init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing - function (advi::ADVI)( q_η::ContinuousMultivariateDistribution, ηs ::AbstractMatrix From 61a2272cfb01d3052595f23fabb4cf85ba81b320 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 26 Aug 2023 21:24:55 +0100 Subject: [PATCH 147/206] merge (manually) commit ff32ac642d6aa3a08d371ed895aa6b4026b06b92 --- src/optimize.jl | 64 +++++++++++++++++++++++++----------------------- test/optimize.jl | 31 ++++++++++++++++++++--- 2 files changed, 61 insertions(+), 34 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index f21e757a..ea2fd5a1 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -8,7 +8,8 @@ end objective ::AbstractVariationalObjective, restructure, λ₀ ::AbstractVector{<:Real}, - n_max_iter ::Int; + n_max_iter ::Int, + objargs...; kwargs... ) @@ -17,7 +18,8 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie optimize( objective ::AbstractVariationalObjective, q, - n_max_iter::Int; + n_max_iter::Int, + objargs...; kwargs... ) @@ -29,36 +31,34 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `restruct`: Function that reconstructs the variational approximation from the flattened parameters. - `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. - `n_max_iter`: Maximum number of iterations. +- `objargs...`: Arguments to be passed to `objective`. +- `kwargs...`: Additional keywoard arguments. (See below.) # Keyword Arguments - `adbackend`: Automatic differentiation backend. (Type: `<: ADtypes.AbstractADType`.) - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; obj_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `obj_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. +- `callback!`: Callback function called after every iteration. The signature is `cb(; stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, `g` is the stochastic estimate of the gradient. (Default: `nothing`.) - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) - -When resuming from the state of a previous run, use the following keyword arguments: -- `opt_state`: Initial state of the optimizer. -- `obj_state`: Initial state of the objective. +- `state`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) (Type: `<: NamedTuple`.) # Returns - `λ`: Variational parameters optimizing the variational objective. -- `stats`: Statistics gathered during inference. -- `opt_state`: Final state of the optimiser. -- `obj_state`: Final state of the objective. +- `logstats`: Statistics and logs gathered during optimization. +- `states`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. """ function optimize( objective ::AbstractVariationalObjective, restructure, λ₀ ::AbstractVector{<:Real}, - n_max_iter ::Int; - adbackend::AbstractADType, + n_max_iter ::Int, + objargs...; + adbackend ::AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), rng ::AbstractRNG = default_rng(), show_progress::Bool = true, - opt_state = nothing, - obj_state = nothing, + state ::NamedTuple = NamedTuple(), callback! = nothing, prog = ProgressMeter.Progress( n_max_iter; @@ -66,37 +66,39 @@ function optimize( barlen = 31, showspeed = true, enabled = show_progress - ) + ) ) - λ = copy(λ₀) - opt_state = isnothing(opt_state) ? Optimisers.setup(optimizer, λ) : opt_state - obj_state = isnothing(obj_state) ? init(rng, objective, λ, restructure) : obj_state - grad_buf = DiffResults.GradientResult(λ) - stats = NamedTuple[] + λ = copy(λ₀) + opt_st = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ) + obj_st = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure) + grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ)) + logstats = NamedTuple[] for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, obj_state, stat′ = estimate_gradient!( - rng, adbackend, objective, obj_state, λ, restructure, grad_buf) + grad_buf, obj_st, stat′ = estimate_gradient( + rng, adbackend, objective, obj_st, + λ, restructure, grad_buf; objargs... + ) stat = merge(stat, stat′) - g = DiffResults.gradient(grad_buf) - opt_state, λ = Optimisers.update!(opt_state, λ, g) - stat′ = (iteration = t,) - stat = merge(stat, stat′) + g = DiffResults.gradient(grad_buf) + opt_st, λ = Optimisers.update!(opt_st, λ, g) if !isnothing(callback!) - stat′ = callback!(; obj_state, stat, restructure, λ, g) + stat′ = callback!(; stat, restructure, λ, g) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end @debug "Iteration $t" stat... pm_next!(prog, stat) - push!(stats, stat) + push!(logstats, stat) end - λ, map(identity, stats), opt_state, obj_state + state = (opt=opt_st, obj=obj_st) + logstats = map(identity, logstats) + λ, logstats, state end function optimize(objective ::AbstractVariationalObjective, @@ -104,8 +106,8 @@ function optimize(objective ::AbstractVariationalObjective, n_max_iter::Int; kwargs...) λ, restructure = Optimisers.destructure(q₀) - λ, stats, opt_state, obj_state = optimize( + λ, logstats, state = optimize( objective, restructure, λ, n_max_iter; kwargs... ) - restructure(λ), stats, opt_state, obj_state + restructure(λ), logstats, state end diff --git a/test/optimize.jl b/test/optimize.jl index 56ca63c0..78d07d00 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -19,7 +19,7 @@ using ReTest optimizer = Optimisers.Adam(1e-2) rng = Philox4x(UInt64, seed, 8) - q_ref, stats_ref, _, _ = optimize( + q_ref, stats_ref, _ = optimize( obj, q₀, T; optimizer, show_progress = false, @@ -32,7 +32,7 @@ using ReTest λ₀, re = Optimisers.destructure(q₀) rng = Philox4x(UInt64, seed, 8) - λ, stats, _, _ = optimize( + λ, stats, _ = optimize( obj, re, λ₀, T; optimizer, show_progress = false, @@ -52,7 +52,7 @@ using ReTest end rng = Philox4x(UInt64, seed, 8) - _, stats, _, _ = optimize( + _, stats, _ = optimize( obj, q₀, T; show_progress = false, rng, @@ -61,4 +61,29 @@ using ReTest ) @test [stat.test_value for stat ∈ stats] == test_values end + + @testset "warm start" begin + rng = Philox4x(UInt64, seed, 8) + + T_first = div(T,2) + T_last = T - T_first + + q_first, _, state = optimize( + obj, q₀, T_first; + optimizer, + show_progress = false, + rng, + adbackend + ) + + q, stats, _ = optimize( + obj, q_first, T_last; + optimizer, + show_progress = false, + state, + rng, + adbackend + ) + @test q == q_ref + end end From c56d29ef1c2954673b7941fce6c3d8d664fe020c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 26 Aug 2023 22:03:36 +0100 Subject: [PATCH 148/206] fix test for new interface, change interface for `optimize`, `advi` --- src/AdvancedVI.jl | 3 +- src/objectives/elbo/advi.jl | 81 ++++++++++++++++++++----------------- src/optimize.jl | 21 ++++++---- test/advi_locscale.jl | 33 ++++++++------- test/optimize.jl | 21 +++++----- 5 files changed, 87 insertions(+), 72 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index db433a67..91f714e4 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -82,6 +82,7 @@ init( """ estimate_gradient!( rng ::AbstractRNG, + prob, adbackend ::AbstractADType, obj ::AbstractVariationalObjective, obj_state, @@ -90,7 +91,7 @@ init( out ::DiffResults.MutableDiffResult ) -Estimate (possibly stochastic) gradients of the objective `obj` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`. +Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`. The estimated objective value and gradient are then stored in `out`. If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`. diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 5a3ce96e..1ce57371 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,6 +1,6 @@ """ - ADVI(prob, n_samples; kwargs...) + ADVI(n_samples; kwargs...) Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. This computes the evidence lower-bound (ELBO) through the ADVI formulation: @@ -19,17 +19,14 @@ This computes the evidence lower-bound (ELBO) through the ADVI formulation: where ``\\phi^{-1}`` is an "inverse bijector." # Arguments -- `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface. - `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.) # Keyword Arguments - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) -- `cv`: A control variate. -- `invbij`: An inverse bijective mapping that matches the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) # Requirements - ``q_{\\lambda}`` implements `rand`. -- `logdensity(prob)` must be differentiable by the selected AD backend. +- The target `logdensity(prob)` must be differentiable by the selected AD backend. Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. @@ -37,27 +34,12 @@ Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. * Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. * Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. """ -struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective - prob ::P - invbij ::B +struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective entropy ::EntropyEst n_samples::Int - function ADVI(prob, - n_samples::Int; - entropy ::AbstractEntropyEstimator = ClosedFormEntropy(), - invbij = Bijectors.identity) - cap = LogDensityProblems.capabilities(prob) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) - end - new{typeof(prob), typeof(invbij), typeof(entropy)}( - prob, invbij, entropy, n_samples - ) + function ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) + new{typeof(entropy)}(entropy, n_samples) end end @@ -65,38 +47,64 @@ Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") function (advi::ADVI)( - q_η::ContinuousMultivariateDistribution, + prob, + q ::ContinuousMultivariateDistribution, + zs::AbstractMatrix +) + 𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs)) + ℍ = advi.entropy(q, zs) + 𝔼ℓ + ℍ +end + +function (advi::ADVI)( + prob, + q_trans::Bijectors.TransformedDistribution, ηs ::AbstractMatrix ) + @unpack dist, transform = q_trans + q = dist + b⁻¹ = transform 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) - LogDensityProblems.logdensity(advi.prob, zᵢ) + logdetjacᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, ηᵢ) + LogDensityProblems.logdensity(prob, zᵢ) + logdetjacᵢ end - ℍ = advi.entropy(q_η, ηs) + ℍ = advi.entropy(q, ηs) 𝔼ℓ + ℍ end """ (advi::ADVI)( - q_η::ContinuousMultivariateDistribution; + prob, q; rng::AbstractRNG = Random.default_rng(), n_samples::Int = advi.n_samples ) -Estimate the ELBO of the variational approximation `q_η` using the ADVI -formulation using `n_samples` number of Monte Carlo samples. +Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation using `n_samples` number of Monte Carlo samples. """ function (advi::ADVI)( - q_η ::ContinuousMultivariateDistribution; + prob, + q ::ContinuousMultivariateDistribution; + rng ::AbstractRNG = default_rng(), + n_samples::Int = advi.n_samples +) + zs = rand(rng, q, n_samples) + advi(q, zs) +end + +function (advi::ADVI)( + prob, + q_trans ::Bijectors.TransformedDistribution; rng ::AbstractRNG = default_rng(), n_samples::Int = advi.n_samples ) - ηs = rand(rng, q_η, n_samples) - advi(q_η, ηs) + q = q_trans.dist + ηs = rand(rng, q, n_samples) + advi(q_trans, ηs) end function estimate_gradient!( rng ::AbstractRNG, + prob, adbackend ::AbstractADType, advi ::ADVI, est_state, @@ -105,9 +113,10 @@ function estimate_gradient!( out ::DiffResults.MutableDiffResult ) f(λ′) = begin - q_η = restructure(λ′) - ηs = rand(rng, q_η, advi.n_samples) - -advi(q_η, ηs) + q_trans = restructure(λ′) + q = q_trans.dist + ηs = rand(rng, q, advi.n_samples) + -advi(prob, q_trans, ηs) end value_and_gradient!(adbackend, f, λ, out) diff --git a/src/optimize.jl b/src/optimize.jl index ea2fd5a1..5425d938 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -5,6 +5,7 @@ end """ optimize( + prob, objective ::AbstractVariationalObjective, restructure, λ₀ ::AbstractVector{<:Real}, @@ -13,9 +14,10 @@ end kwargs... ) -Optimize the variational objective `objective` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`. +Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`. optimize( + prob, objective ::AbstractVariationalObjective, q, n_max_iter::Int, @@ -23,7 +25,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie kwargs... ) -Optimize the variational objective `objective` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface. +Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface. # Arguments - `objective`: Variational Objective. @@ -49,6 +51,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `states`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. """ function optimize( + prob, objective ::AbstractVariationalObjective, restructure, λ₀ ::AbstractVector{<:Real}, @@ -77,9 +80,9 @@ function optimize( for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, obj_st, stat′ = estimate_gradient( - rng, adbackend, objective, obj_st, - λ, restructure, grad_buf; objargs... + grad_buf, obj_st, stat′ = estimate_gradient!( + rng, prob, adbackend, objective, obj_st, + λ, restructure, grad_buf, objargs... ) stat = merge(stat, stat′) @@ -101,13 +104,15 @@ function optimize( λ, logstats, state end -function optimize(objective ::AbstractVariationalObjective, +function optimize(prob, + objective ::AbstractVariationalObjective, q₀, - n_max_iter::Int; + n_max_iter::Int, + objargs...; kwargs...) λ, restructure = Optimisers.destructure(q₀) λ, logstats, state = optimize( - objective, restructure, λ, n_max_iter; kwargs... + prob, objective, restructure, λ, n_max_iter, objargs...; kwargs... ) restructure(λ), logstats, state end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 93ece412..85cfea71 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -11,8 +11,8 @@ using ReTest :NormalLogNormalMeanField => normallognormal_meanfield, ), (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹), - :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹, entropy = StickingTheLandingEntropy()), + :ADVIClosedFormEntropy => ADVI(10), + :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), @@ -34,22 +34,21 @@ using ReTest μ₀ = zeros(realtype, n_dims) L₀ = Diagonal(ones(realtype, n_dims)) - q₀ = TuringDiagMvNormal(μ₀, diag(L₀)) - - obj = objective(model, b⁻¹, 10) + q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) + q₀_z = Bijectors.transformed(q₀_η, b⁻¹) @testset "convergence" begin Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats, _, _ = optimize( - obj, q₀, T; + q, stats, _ = optimize( + model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng, adbackend = adbackend, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q.dist) + L = sqrt(cov(q.dist)) Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) @test Δλ ≤ Δλ₀/T^(1/4) @@ -59,26 +58,26 @@ using ReTest @testset "determinism" begin rng = Philox4x(UInt64, seed, 8) - q, stats, _, _ = optimize( - obj, q₀, T; + q, stats, _ = optimize( + model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng, adbackend = adbackend, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q.dist) + L = sqrt(cov(q.dist)) rng_repl = Philox4x(UInt64, seed, 8) - q, stats, _, _ = optimize( - obj, q₀, T; + q, stats, _ = optimize( + model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng_repl, adbackend = adbackend, ) - μ_repl = mean(q) - L_repl = sqrt(cov(q)) + μ_repl = mean(q.dist) + L_repl = sqrt(cov(q.dist)) @test μ == μ_repl @test L == L_repl end diff --git a/test/optimize.jl b/test/optimize.jl index 78d07d00..2af56c1f 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -11,16 +11,17 @@ using ReTest @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats # Global Test Configurations - b⁻¹ = Bijectors.bijector(model) |> inverse - q₀ = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) - obj = ADVI(model, 10; invbij=b⁻¹) + b⁻¹ = Bijectors.bijector(model) |> inverse + q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q₀_z = Bijectors.transformed(q₀_η, b⁻¹) + obj = ADVI(10) adbackend = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) rng = Philox4x(UInt64, seed, 8) q_ref, stats_ref, _ = optimize( - obj, q₀, T; + model, obj, q₀_z, T; optimizer, show_progress = false, rng, @@ -29,11 +30,11 @@ using ReTest λ_ref, _ = Optimisers.destructure(q_ref) @testset "restructure" begin - λ₀, re = Optimisers.destructure(q₀) + λ₀, re = Optimisers.destructure(q₀_z) rng = Philox4x(UInt64, seed, 8) λ, stats, _ = optimize( - obj, re, λ₀, T; + model, obj, re, λ₀, T; optimizer, show_progress = false, rng, @@ -47,13 +48,13 @@ using ReTest rng = Philox4x(UInt64, seed, 8) test_values = rand(rng, T) - callback!(; stat, obj_state, restructure, λ, g) = begin + callback!(; stat, restructure, λ, g) = begin (test_value = test_values[stat.iteration],) end rng = Philox4x(UInt64, seed, 8) _, stats, _ = optimize( - obj, q₀, T; + model, obj, q₀_z, T; show_progress = false, rng, adbackend, @@ -69,7 +70,7 @@ using ReTest T_last = T - T_first q_first, _, state = optimize( - obj, q₀, T_first; + model, obj, q₀_z, T_first; optimizer, show_progress = false, rng, @@ -77,7 +78,7 @@ using ReTest ) q, stats, _ = optimize( - obj, q_first, T_last; + model, obj, q_first, T_last; optimizer, show_progress = false, state, From 913b46953fbaaf81150bb308daee6c06d7bfa47d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 1 Sep 2023 00:27:14 -0400 Subject: [PATCH 149/206] fix integer subtype error in documentation of advi Co-authored-by: Tor Erlend Fjelde --- src/objectives/elbo/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 1ce57371..77e1c750 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -19,7 +19,7 @@ This computes the evidence lower-bound (ELBO) through the ADVI formulation: where ``\\phi^{-1}`` is an "inverse bijector." # Arguments -- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.) +- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO. # Keyword Arguments - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) From 385a653c2d1c37e6fe088c6ecc08b86647f16159 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 1 Sep 2023 00:28:31 -0400 Subject: [PATCH 150/206] fix remove redundant argument for `advi` --- src/objectives/elbo/advi.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index f9a61d81..b70c3299 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -49,7 +49,6 @@ Base.show(io::IO, advi::ADVI) = init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing function (advi::ADVI)( - rng::AbstractRNG, q_η::ContinuousMultivariateDistribution, ηs ::AbstractMatrix ) @@ -81,7 +80,7 @@ function (advi::ADVI)( n_samples::Int = advi.n_samples ) ηs = rand(rng, q_η, n_samples) - advi(rng, q_η, ηs) + advi(q_η, ηs) end function estimate_gradient( @@ -96,7 +95,7 @@ function estimate_gradient( f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) - -advi(rng, q_η, ηs) + -advi(q_η, ηs) end value_and_gradient!(adbackend, f, λ, out) From c9df90e72a842b15c2aa6c41d32ecb14331a7c6b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 1 Sep 2023 00:31:48 -0400 Subject: [PATCH 151/206] remove manifest --- README.md | 1 - test/Manifest.toml | 866 --------------------------------------------- 2 files changed, 867 deletions(-) delete mode 100644 test/Manifest.toml diff --git a/README.md b/README.md index f0bf6cc1..18ba63e5 100644 --- a/README.md +++ b/README.md @@ -248,4 +248,3 @@ end - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015. - Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882. - Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003. - diff --git a/test/Manifest.toml b/test/Manifest.toml deleted file mode 100644 index 220b42bb..00000000 --- a/test/Manifest.toml +++ /dev/null @@ -1,866 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.9.2" -manifest_format = "2.0" -project_hash = "a6495d9f0ea044fd0a55c1c989f1adca1ad5c855" - -[[deps.ADTypes]] -git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a" -uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -version = "0.2.1" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Bijectors]] -deps = ["ArgCheck", "ChainRules", "ChainRulesCore", "ChangesOfVariables", "Compat", "Distributions", "Functors", "InverseFunctions", "IrrationalConstants", "LinearAlgebra", "LogExpFunctions", "MappedArrays", "Random", "Reexport", "Requires", "Roots", "SparseArrays", "Statistics"] -git-tree-sha1 = "af192c7c235264bdc6f67321fd1c57be0dd7ffb5" -uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.6" - - [deps.Bijectors.extensions] - BijectorsDistributionsADExt = "DistributionsAD" - BijectorsForwardDiffExt = "ForwardDiff" - BijectorsLazyArraysExt = "LazyArrays" - BijectorsReverseDiffExt = "ReverseDiff" - BijectorsTrackerExt = "Tracker" - BijectorsZygoteExt = "Zygote" - - [deps.Bijectors.weakdeps] - DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.53.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.16.0" - -[[deps.ChangesOfVariables]] -deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.8" -weakdeps = ["InverseFunctions"] - - [deps.ChangesOfVariables.extensions] - ChangesOfVariablesInverseFunctionsExt = "InverseFunctions" - -[[deps.CommonSolve]] -git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" -uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" -version = "0.2.4" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Comonicon]] -deps = ["Configurations", "ExproniconLite", "Libdl", "Logging", "Markdown", "OrderedCollections", "PackageCompiler", "Pkg", "Scratch", "TOML", "UUIDs"] -git-tree-sha1 = "9c360961f23e2fae4c6549bbba58a6f39c9e145c" -uuid = "863f3e99-da2a-4334-8734-de3dacbe5542" -version = "1.0.5" - -[[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.9.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" - -[[deps.Configurations]] -deps = ["ExproniconLite", "OrderedCollections", "TOML"] -git-tree-sha1 = "434f446dbf89d08350e83bf57c0fc86f5d3ffd4e" -uuid = "5218b696-f38b-4ac9-8b61-a12ec717816d" -version = "0.17.5" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.3" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.100" - - [deps.Distributions.extensions] - DistributionsChainRulesCoreExt = "ChainRulesCore" - DistributionsDensityInterfaceExt = "DensityInterface" - - [deps.Distributions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" - -[[deps.DistributionsAD]] -deps = ["Adapt", "ChainRules", "ChainRulesCore", "Compat", "Distributions", "FillArrays", "LinearAlgebra", "PDMats", "Random", "Requires", "SpecialFunctions", "StaticArrays", "StatsFuns", "ZygoteRules"] -git-tree-sha1 = "975de103eb2175cf54bf14b15ded2c68625eabdf" -uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.52" - - [deps.DistributionsAD.extensions] - DistributionsADForwardDiffExt = "ForwardDiff" - DistributionsADLazyArraysExt = "LazyArrays" - DistributionsADReverseDiffExt = "ReverseDiff" - DistributionsADTrackerExt = "Tracker" - - [deps.DistributionsAD.weakdeps] - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - -[[deps.Enzyme]] -deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random"] -git-tree-sha1 = "1f85bc8a9da6118abb95d134efc68cf4a6957341" -uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" -version = "0.11.7" - -[[deps.EnzymeCore]] -deps = ["Adapt"] -git-tree-sha1 = "643995502bdfff08bf080212c92430510be01ad5" -uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" -version = "0.5.2" - -[[deps.Enzyme_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "ffa4926cc857bcc5c256825bd7273a6ac989eb34" -uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef" -version = "0.0.80+0" - -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" - -[[deps.ExproniconLite]] -deps = ["Pkg", "TOML"] -git-tree-sha1 = "d80b5d5990071086edf5de9018c6c69c83937004" -uuid = "55351af7-c7e9-48d6-89ff-24e801d99491" -version = "0.10.3" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "048dd3d82558759476cff9cff999219216932a08" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.6.0" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.FunctionWrappers]] -git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" -uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.3" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.5" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.8.1" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.5" - -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.21.4" - -[[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.10" - -[[deps.InlineTest]] -deps = ["Test"] -git-tree-sha1 = "daf0743879904f0ad645ca6594e1479685f158a2" -uuid = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6" -version = "0.2.0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.12" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.8" -weakdeps = ["EnzymeCore"] - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.1.0" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.23+0" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogDensityProblems]] -deps = ["ArgCheck", "DocStringExtensions", "Random"] -git-tree-sha1 = "f9a11237204bc137617194d79d813069838fcf61" -uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -version = "2.1.1" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.26" -weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"] - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.11" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.1.0" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.4" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.ObjectFile]] -deps = ["Reexport", "StructIO"] -git-tree-sha1 = "69607899b46e1f8ead70396bc51a4c361478d8f6" -uuid = "d8793406-e978-5875-9003-1fc021f44a92" -version = "0.4.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.20" - -[[deps.OrderedCollections]] -git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.2" - -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.17" - -[[deps.PackageCompiler]] -deps = ["Artifacts", "LazyArtifacts", "Libdl", "Pkg", "Printf", "RelocatableFolders", "TOML", "UUIDs"] -git-tree-sha1 = "1a6a868eb755e8ea9ecd000aa6ad175def0cc85b" -uuid = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d" -version = "2.1.7" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.0" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.8.2" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.6.1" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - -[[deps.ReTest]] -deps = ["Distributed", "InlineTest", "Printf", "Random", "Sockets", "Test"] -git-tree-sha1 = "dd8f6587c0abac44bcec2e42f0aeddb73550c0ec" -uuid = "e0db7c4e-2690-44b9-bad6-7687da720f89" -version = "0.3.2" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.RelocatableFolders]] -deps = ["SHA", "Scratch"] -git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" -uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.0" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.ReverseDiff]] -deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] -git-tree-sha1 = "d1235bdd57a93bd7504225b792b867e9a7df38d5" -uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -version = "1.15.1" - -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" - -[[deps.Roots]] -deps = ["ChainRulesCore", "CommonSolve", "Printf", "Setfield"] -git-tree-sha1 = "ff42754a57bb0d6dcfe302fd0d4272853190421f" -uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" -version = "2.0.19" - - [deps.Roots.extensions] - RootsForwardDiffExt = "ForwardDiff" - RootsIntervalRootFindingExt = "IntervalRootFinding" - RootsSymPyExt = "SymPy" - - [deps.Roots.weakdeps] - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807" - SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.SimpleUnPack]] -git-tree-sha1 = "58e6353e72cde29b90a69527e56df1b5c3d8c437" -uuid = "ce78b400-467f-4804-87d8-8f486da07d0a" -version = "1.1.0" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.2" -weakdeps = ["Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.6.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.0" - -[[deps.StatsFuns]] -deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.0" -weakdeps = ["ChainRulesCore", "InverseFunctions"] - - [deps.StatsFuns.extensions] - StatsFunsChainRulesCoreExt = "ChainRulesCore" - StatsFunsInverseFunctionsExt = "InverseFunctions" - -[[deps.StructArrays]] -deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] -git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.15" - -[[deps.StructIO]] -deps = ["Test"] -git-tree-sha1 = "010dc73c7146869c042b49adcdb6bf528c12e859" -uuid = "53d494c1-5632-5724-8f4c-31dff12d585f" -version = "0.3.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.1" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.23" - -[[deps.Tracker]] -deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] -git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c" -uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.26" -weakdeps = ["PDMats"] - - [deps.Tracker.extensions] - TrackerPDMatsExt = "PDMats" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.63" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.3" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" From 19d11d141a788ce6f476172f02c004854a0d892d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 1 Sep 2023 00:44:21 -0400 Subject: [PATCH 152/206] refactor remove imports and use fully qualified names --- src/AdvancedVI.jl | 22 ++++++---------------- src/objectives/elbo/advi.jl | 16 ++++++++-------- src/objectives/elbo/entropy.jl | 2 +- src/optimize.jl | 4 ++-- src/utils.jl | 0 5 files changed, 17 insertions(+), 27 deletions(-) delete mode 100644 src/utils.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 91f714e4..1c662cfb 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -4,11 +4,8 @@ module AdvancedVI using SimpleUnPack: @unpack, @pack! using Accessors -using Random: AbstractRNG, default_rng +using Random using Distributions -import Distributions: - logpdf, _logpdf, rand, rand!, _rand!, - ContinuousMultivariateDistribution using Functors using Optimisers @@ -17,19 +14,16 @@ using DocStringExtensions using ProgressMeter using LinearAlgebra -using LinearAlgebra: AbstractTriangular using LogDensityProblems using ADTypes, DiffResults -using ADTypes: AbstractADType -using ChainRulesCore: @ignore_derivatives +using ChainRulesCore using FillArrays using Bijectors using StatsBase -using StatsBase: entropy # derivatives """ @@ -59,7 +53,7 @@ abstract type AbstractVariationalObjective end """ init( - rng::AbstractRNG, + rng::Random.AbstractRNG, obj::AbstractVariationalObjective, λ::AbstractVector, restructure @@ -73,7 +67,7 @@ This function needs to be implemented only if `obj` is stateful. notice. """ init( - rng::AbstractRNG, + rng::Random.AbstractRNG, obj::AbstractVariationalObjective, λ::AbstractVector, restructure @@ -81,9 +75,9 @@ init( """ estimate_gradient!( - rng ::AbstractRNG, + rng ::Random.AbstractRNG, prob, - adbackend ::AbstractADType, + adbackend ::ADTypes.AbstractADType, obj ::AbstractVariationalObjective, obj_state, λ ::AbstractVector, @@ -114,7 +108,6 @@ include("objectives/elbo/entropy.jl") include("objectives/elbo/advi.jl") export - ELBO, ADVI, ClosedFormEntropy, StickingTheLandingEntropy, @@ -128,9 +121,6 @@ include("optimize.jl") export optimize -include("utils.jl") - - # optional dependencies if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base using Requires diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 77e1c750..97a08b95 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -48,7 +48,7 @@ Base.show(io::IO, advi::ADVI) = function (advi::ADVI)( prob, - q ::ContinuousMultivariateDistribution, + q ::Distributions.ContinuousMultivariateDistribution, zs::AbstractMatrix ) 𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs)) @@ -59,7 +59,7 @@ end function (advi::ADVI)( prob, q_trans::Bijectors.TransformedDistribution, - ηs ::AbstractMatrix + ηs ::AbstractMatrix ) @unpack dist, transform = q_trans q = dist @@ -84,8 +84,8 @@ Estimate the ELBO of the variational approximation `q` of the target `prob` usin function (advi::ADVI)( prob, q ::ContinuousMultivariateDistribution; - rng ::AbstractRNG = default_rng(), - n_samples::Int = advi.n_samples + rng ::Random.AbstractRNG = Random.default_rng(), + n_samples::Int = advi.n_samples ) zs = rand(rng, q, n_samples) advi(q, zs) @@ -94,8 +94,8 @@ end function (advi::ADVI)( prob, q_trans ::Bijectors.TransformedDistribution; - rng ::AbstractRNG = default_rng(), - n_samples::Int = advi.n_samples + rng ::Random.AbstractRNG = Random.default_rng(), + n_samples::Int = advi.n_samples ) q = q_trans.dist ηs = rand(rng, q, n_samples) @@ -103,9 +103,9 @@ function (advi::ADVI)( end function estimate_gradient!( - rng ::AbstractRNG, + rng ::Random.AbstractRNG, prob, - adbackend ::AbstractADType, + adbackend ::ADTypes.AbstractADType, advi ::ADVI, est_state, λ ::Vector{<:Real}, diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index e6212c46..63854ec0 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -28,7 +28,7 @@ The "sticking the landing" entropy estimator. struct StickingTheLandingEntropy <: AbstractEntropyEstimator end function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix) - @ignore_derivatives mean(eachcol(ηs)) do ηᵢ + ChainRulesCore.@ignore_derivatives mean(eachcol(ηs)) do ηᵢ -logpdf(q, ηᵢ) end end diff --git a/src/optimize.jl b/src/optimize.jl index 5425d938..85cac75e 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -57,9 +57,9 @@ function optimize( λ₀ ::AbstractVector{<:Real}, n_max_iter ::Int, objargs...; - adbackend ::AbstractADType, + adbackend ::ADTypes.AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), - rng ::AbstractRNG = default_rng(), + rng ::Random.AbstractRNG = Random.default_rng(), show_progress::Bool = true, state ::NamedTuple = NamedTuple(), callback! = nothing, diff --git a/src/utils.jl b/src/utils.jl deleted file mode 100644 index e69de29b..00000000 From 59bd4f848bc034b8428408381e951fead53d6f74 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 1 Sep 2023 00:46:24 -0400 Subject: [PATCH 153/206] update documentation for `AbstractVariationalObjective` Co-authored-by: Tor Erlend Fjelde --- src/AdvancedVI.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 1c662cfb..8abf3da9 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -43,7 +43,7 @@ function value_and_gradient! end # estimators """ - abstract type AbstractVariationalObjective end + AbstractVariationalObjective An VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`. Furthermore, it should implement the functions `estimate_gradient`. From dedc5cf1a99bf7771d68380126e480129ed050af Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 1 Sep 2023 01:01:07 -0400 Subject: [PATCH 154/206] refactor use StableRNG instead of Random123 --- test/Project.toml | 2 +- test/advi_locscale.jl | 8 ++++---- test/optimize.jl | 14 +++++++------- test/runtests.jl | 3 +-- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 5ce8fcd8..2c06aa53 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,10 +14,10 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Random123 = "74087812-796a-5b5d-8853-05524746bad3" ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 85cfea71..7c60188c 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -21,8 +21,8 @@ using ReTest # :Enzyme => AutoEnzyme(), ) - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) modelstats = modelconstr(realtype; rng) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats @@ -57,7 +57,7 @@ using ReTest end @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) + rng = StableRNG(seed) q, stats, _ = optimize( model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), @@ -68,7 +68,7 @@ using ReTest μ = mean(q.dist) L = sqrt(cov(q.dist)) - rng_repl = Philox4x(UInt64, seed, 8) + rng_repl = StableRNG(seed) q, stats, _ = optimize( model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), diff --git a/test/optimize.jl b/test/optimize.jl index 2af56c1f..c7173f51 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -2,8 +2,8 @@ using ReTest @testset "optimize" begin - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) T = 1000 modelstats = normallognormal_meanfield(Float64; rng) @@ -19,7 +19,7 @@ using ReTest adbackend = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) - rng = Philox4x(UInt64, seed, 8) + rng = StableRNG(seed) q_ref, stats_ref, _ = optimize( model, obj, q₀_z, T; optimizer, @@ -32,7 +32,7 @@ using ReTest @testset "restructure" begin λ₀, re = Optimisers.destructure(q₀_z) - rng = Philox4x(UInt64, seed, 8) + rng = StableRNG(seed) λ, stats, _ = optimize( model, obj, re, λ₀, T; optimizer, @@ -45,14 +45,14 @@ using ReTest end @testset "callback" begin - rng = Philox4x(UInt64, seed, 8) + rng = StableRNG(seed) test_values = rand(rng, T) callback!(; stat, restructure, λ, g) = begin (test_value = test_values[stat.iteration],) end - rng = Philox4x(UInt64, seed, 8) + rng = StableRNG(seed) _, stats, _ = optimize( model, obj, q₀_z, T; show_progress = false, @@ -64,7 +64,7 @@ using ReTest end @testset "warm start" begin - rng = Philox4x(UInt64, seed, 8) + rng = StableRNG(seed) T_first = div(T,2) T_last = T - T_first diff --git a/test/runtests.jl b/test/runtests.jl index fd68ed79..ef85f16b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,8 +3,7 @@ using ReTest using ReTest: @testset, @test using Comonicon -using Random -using Random123 +using Random, StableRNGs using Statistics using Distributions using LinearAlgebra From e35dc67f24f1d5c60e4a5c5959b0e085f19506a9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 1 Sep 2023 01:08:18 -0400 Subject: [PATCH 155/206] refactor migrate to Test, re-enable x86 tests --- .github/workflows/CI.yml | 2 +- test/Project.toml | 2 +- test/ad.jl | 2 +- test/advi_locscale.jl | 2 +- test/optimize.jl | 2 +- test/runtests.jl | 9 ++------- 6 files changed, 7 insertions(+), 12 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7ba573a1..9731f20c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: - windows-latest arch: - x64 - # - x86 # Uncomment after https://github.com/JuliaTesting/ReTest.jl/pull/52 is merged + - x86 exclude: - os: macOS-latest arch: x86 diff --git a/test/Project.toml b/test/Project.toml index 2c06aa53..0e81ec08 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,10 +14,10 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/ad.jl b/test/ad.jl index f575b485..b716ca2f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,5 +1,5 @@ -using ReTest +using Test @testset "ad" begin @testset "$(adname)" for (adname, adsymbol) ∈ Dict( diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 7c60188c..db2338a3 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -1,7 +1,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false -using ReTest +using Test @testset "advi" begin @testset "locscale" begin diff --git a/test/optimize.jl b/test/optimize.jl index c7173f51..6f7986d0 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -1,5 +1,5 @@ -using ReTest +using Test @testset "optimize" begin seed = (0x38bef07cf9cc549d) diff --git a/test/runtests.jl b/test/runtests.jl index ef85f16b..a4220f98 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ -using ReTest -using ReTest: @testset, @test +using Test +using Test: @testset, @test using Comonicon using Random, StableRNGs @@ -38,8 +38,3 @@ include("models/normallognormal.jl") include("ad.jl") include("advi_locscale.jl") include("optimize.jl") - -@main function runtests(patterns...; dry::Bool = false) - retest(patterns...; dry = dry, verbose = Inf) -end - From 641318331387e3de818cf2c159ae7bc41e313abe Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 5 Sep 2023 18:30:50 +0100 Subject: [PATCH 156/206] refactor remove inner constructor for `ADVI` --- src/objectives/elbo/advi.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 97a08b95..a7d655d3 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -37,12 +37,10 @@ Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective entropy ::EntropyEst n_samples::Int - - function ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) - new{typeof(entropy)}(entropy, n_samples) - end end +ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) = ADVI(entropy, n_samples) + Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") From 1668bae6ee3532fcd751037022f1d6da4dc7257c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 6 Sep 2023 21:56:38 -0400 Subject: [PATCH 157/206] fix swap `export`s and `include`s Co-authored-by: Tor Erlend Fjelde --- src/AdvancedVI.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 8abf3da9..d4d776f2 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -103,15 +103,14 @@ function estimate_gradient! end # ADVI-specific interfaces abstract type AbstractEntropyEstimator end -# entropy.jl must preceed advi.jl -include("objectives/elbo/entropy.jl") -include("objectives/elbo/advi.jl") - export ADVI, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy +# entropy.jl must preceed advi.jl +include("objectives/elbo/entropy.jl") +include("objectives/elbo/advi.jl") # Optimization Routine From a8f12541c32cc1c74408ffce84257ce0a4eab526 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 6 Sep 2023 21:57:24 -0400 Subject: [PATCH 158/206] fix doscs for `ADVI` Co-authored-by: Tor Erlend Fjelde --- src/objectives/elbo/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index a7d655d3..5ba0ef34 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -26,7 +26,7 @@ where ``\\phi^{-1}`` is an "inverse bijector." # Requirements - ``q_{\\lambda}`` implements `rand`. -- The target `logdensity(prob)` must be differentiable by the selected AD backend. +- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend. Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. From 7b368c12b1dd198ba73607dc565465f1f16c5890 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 6 Sep 2023 21:59:13 -0400 Subject: [PATCH 159/206] fix use `FillArrays` in the test problems Co-authored-by: Tor Erlend Fjelde --- test/advi_locscale.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index db2338a3..033736df 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -31,8 +31,8 @@ using Test b = Bijectors.bijector(model) b⁻¹ = inverse(b) - μ₀ = zeros(realtype, n_dims) - L₀ = Diagonal(ones(realtype, n_dims)) + μ₀ = Zeros(realtype, n_dims) + L₀ = Diagonal(Ones(realtype, n_dims)) q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) q₀_z = Bijectors.transformed(q₀_η, b⁻¹) From f216b376f50fdc22e639406eee654d4a0b1922d6 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 6 Sep 2023 22:19:39 -0400 Subject: [PATCH 160/206] fix `optimize` docs Co-authored-by: Tor Erlend Fjelde --- src/optimize.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 85cac75e..2d8f57f6 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -37,13 +37,13 @@ Optimize the variational objective `objective` targeting `prob` by estimating (s - `kwargs...`: Additional keywoard arguments. (See below.) # Keyword Arguments -- `adbackend`: Automatic differentiation backend. (Type: `<: ADtypes.AbstractADType`.) -- `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) -- `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) -- `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) +- `adbackend::ADtypes.AbstractADType`: Automatic differentiation backend. +- `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.) +- `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.) +- `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) - `callback!`: Callback function called after every iteration. The signature is `cb(; stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, `g` is the stochastic estimate of the gradient. (Default: `nothing`.) - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) -- `state`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) (Type: `<: NamedTuple`.) +- `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) # Returns - `λ`: Variational parameters optimizing the variational objective. From 9e0338db1601ec02a63ea04470cfbdfaab430d6c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 00:06:14 -0400 Subject: [PATCH 161/206] fix improve argument names and docs for `optimize` --- src/optimize.jl | 95 ++++++++++++++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 2d8f57f6..44617f85 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -5,34 +5,34 @@ end """ optimize( - prob, - objective ::AbstractVariationalObjective, + problem, + objective ::AbstractVariationalObjective, restructure, - λ₀ ::AbstractVector{<:Real}, - n_max_iter ::Int, + param_init ::AbstractVector{<:Real}, + max_iter ::Int, objargs...; kwargs... ) -Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`. +Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` to the function `restructure`. optimize( - prob, - objective ::AbstractVariationalObjective, - q, - n_max_iter::Int, + problem, + objective ::AbstractVariationalObjective, + variational_dist_init, + max_iter ::Int, objargs...; kwargs... ) -Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface. +Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the initial variational approximation `variational_dist_init` supports the `Optimisers.destructure` interface. # Arguments - `objective`: Variational Objective. -- `λ₀`: Initial value of the variational parameters. +- `param_init`: Initial value of the variational parameters. - `restruct`: Function that reconstructs the variational approximation from the flattened parameters. -- `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. -- `n_max_iter`: Maximum number of iterations. +- `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`. +- `max_iter`: Maximum number of iterations. - `objargs...`: Arguments to be passed to `objective`. - `kwargs...`: Additional keywoard arguments. (See below.) @@ -41,47 +41,64 @@ Optimize the variational objective `objective` targeting `prob` by estimating (s - `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.) - `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.) - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, `g` is the stochastic estimate of the gradient. (Default: `nothing`.) +- `callback!`: Callback function called after every iteration. See further information below. (Default: `nothing`.) - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) - `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) # Returns -- `λ`: Variational parameters optimizing the variational objective. -- `logstats`: Statistics and logs gathered during optimization. -- `states`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. +- `params`: Variational parameters optimizing the variational objective. +- `stats`: Statistics gathered during optimization. +- `state`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. + +# Callback +The callback function `callback!` has a signature of + + cb(; stat, state, param, restructure, gradient) + +The arguments are as follows: +- `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. +- `state`: Collection of the internal states used for optimization. +- `param`: Variational parameters. +- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. +- `gradient`: The estimated (possibly stochastic) gradient. + +`cb` can return a `NamedTuple` containing some additional information computed within `cb`. +This will be appended to the statistic of the current corresponding iteration. +Otherwise, just return `nothing`. + """ function optimize( - prob, + problem, objective ::AbstractVariationalObjective, restructure, - λ₀ ::AbstractVector{<:Real}, - n_max_iter ::Int, + params_init ::AbstractVector{<:Real}, + max_iter ::Int, objargs...; adbackend ::ADTypes.AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), rng ::Random.AbstractRNG = Random.default_rng(), show_progress::Bool = true, - state ::NamedTuple = NamedTuple(), + state_init ::NamedTuple = NamedTuple(), callback! = nothing, prog = ProgressMeter.Progress( - n_max_iter; + max_iter; desc = "Optimizing", barlen = 31, showspeed = true, enabled = show_progress ) ) - λ = copy(λ₀) - opt_st = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ) - obj_st = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure) + λ = copy(params_init) + opt_st = haskey(state_init, :opt) ? state_init.opt : Optimisers.setup(optimizer, λ) + obj_st = haskey(state_init, :obj) ? state_init.obj : init(rng, objective, λ, restructure) grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ)) - logstats = NamedTuple[] + stats = NamedTuple[] - for t = 1:n_max_iter + for t = 1:max_iter stat = (iteration=t,) grad_buf, obj_st, stat′ = estimate_gradient!( - rng, prob, adbackend, objective, obj_st, + rng, problem, adbackend, objective, obj_st, λ, restructure, grad_buf, objargs... ) stat = merge(stat, stat′) @@ -90,29 +107,33 @@ function optimize( opt_st, λ = Optimisers.update!(opt_st, λ, g) if !isnothing(callback!) - stat′ = callback!(; stat, restructure, λ, g) + stat′ = callback!( + ; stat, restructure, params=λ, gradient=g, + state=(optimizer=opt_st, objective=obj_st) + ) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end @debug "Iteration $t" stat... pm_next!(prog, stat) - push!(logstats, stat) + push!(stats, stat) end - state = (opt=opt_st, obj=obj_st) - logstats = map(identity, logstats) - λ, logstats, state + state = (optimizer=opt_st, objective=obj_st) + stats = map(identity, stats) + params = λ + params, stats, state end -function optimize(prob, +function optimize(problem, objective ::AbstractVariationalObjective, - q₀, + variational_dist_init, n_max_iter::Int, objargs...; kwargs...) - λ, restructure = Optimisers.destructure(q₀) + λ, restructure = Optimisers.destructure(variational_dist_init) λ, logstats, state = optimize( - prob, objective, restructure, λ, n_max_iter, objargs...; kwargs... + problem, objective, restructure, λ, n_max_iter, objargs...; kwargs... ) restructure(λ), logstats, state end From d6fcaf6ec86b9d2d06e9e26506d97cbf316661e2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 00:19:11 -0400 Subject: [PATCH 162/206] fix tests to match new interface of `optimize` --- test/optimize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/optimize.jl b/test/optimize.jl index 6f7986d0..21718f52 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -48,7 +48,7 @@ using Test rng = StableRNG(seed) test_values = rand(rng, T) - callback!(; stat, restructure, λ, g) = begin + callback!(; stat, args...) = begin (test_value = test_values[stat.iteration],) end @@ -81,7 +81,7 @@ using Test model, obj, q_first, T_last; optimizer, show_progress = false, - state, + state_init = state, rng, adbackend ) From 5799f1e246a9c1edd86a307d0644ba09369e9ae3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 00:19:33 -0400 Subject: [PATCH 163/206] refactor move utility functions to new file --- src/AdvancedVI.jl | 13 ++++++++----- src/optimize.jl | 8 ++------ src/utils.jl | 23 +++++++++++++++++++++++ 3 files changed, 33 insertions(+), 11 deletions(-) create mode 100644 src/utils.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d4d776f2..35b493c8 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -67,10 +67,10 @@ This function needs to be implemented only if `obj` is stateful. notice. """ init( - rng::Random.AbstractRNG, - obj::AbstractVariationalObjective, - λ::AbstractVector, - restructure + ::Random.AbstractRNG, + ::AbstractVariationalObjective, + ::AbstractVector, + ::Any ) = nothing """ @@ -108,6 +108,7 @@ export ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy + # entropy.jl must preceed advi.jl include("objectives/elbo/entropy.jl") include("objectives/elbo/advi.jl") @@ -116,9 +117,11 @@ include("objectives/elbo/advi.jl") function optimize end +export optimize + +include("utils.jl") include("optimize.jl") -export optimize # optional dependencies if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base diff --git a/src/optimize.jl b/src/optimize.jl index 44617f85..5f257c42 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -1,8 +1,4 @@ -function pm_next!(pm, stats::NamedTuple) - ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) -end - """ optimize( problem, @@ -89,8 +85,8 @@ function optimize( ) ) λ = copy(params_init) - opt_st = haskey(state_init, :opt) ? state_init.opt : Optimisers.setup(optimizer, λ) - obj_st = haskey(state_init, :obj) ? state_init.obj : init(rng, objective, λ, restructure) + opt_st = maybe_init_optimizer(state_init, optimizer, λ) + obj_st = maybe_init_objective(state_init, rng, objective, λ, restructure) grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ)) stats = NamedTuple[] diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 00000000..ce11d0be --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,23 @@ + +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end + +function maybe_init_optimizer( + state_init::Union{Nothing, NamedTuple}, + optimizer ::Optimisers.AbstractRule, + λ ::AbstractVector +) + haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, λ) +end + +function maybe_init_objective( + state_init::Union{Nothing, NamedTuple}, + rng ::Random.AbstractRNG, + objective ::AbstractVariationalObjective, + λ ::AbstractVector, + restructure +) + haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure) +end + From 2229d61d229e130c557f2797ed8ac7affffec193 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 00:21:21 -0400 Subject: [PATCH 164/206] fix docs for `optimize` Co-authored-by: Tor Erlend Fjelde --- src/optimize.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index 5f257c42..97146319 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -30,7 +30,6 @@ Optimize the variational objective `objective` targeting the problem `problem` b - `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`. - `max_iter`: Maximum number of iterations. - `objargs...`: Arguments to be passed to `objective`. -- `kwargs...`: Additional keywoard arguments. (See below.) # Keyword Arguments - `adbackend::ADtypes.AbstractADType`: Automatic differentiation backend. From bc48e14b31eb8c35387f5dbadf21e4d519689c4c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 00:36:09 -0400 Subject: [PATCH 165/206] refactor advi internal objective Co-authored-by: Tor Erlend Fjelde --- src/objectives/elbo/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 5ba0ef34..37b53a47 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -110,7 +110,7 @@ function estimate_gradient!( restructure, out ::DiffResults.MutableDiffResult ) - f(λ′) = begin + function f(λ′) q_trans = restructure(λ′) q = q_trans.dist ηs = rand(rng, q, advi.n_samples) From 9949a04bff27a1ba69523a8e8eb96bcbeadea09c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 00:38:25 -0400 Subject: [PATCH 166/206] refactor move `rng` to be an optional first argument --- src/optimize.jl | 74 ++++++++++++++++++++++++++++++++++++++++--- test/advi_locscale.jl | 9 ++---- test/optimize.jl | 32 +++++++++++++------ test/runtests.jl | 2 +- 4 files changed, 95 insertions(+), 22 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 97146319..47fd102f 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -62,7 +62,30 @@ This will be appended to the statistic of the current corresponding iteration. Otherwise, just return `nothing`. """ + function optimize( + problem, + objective ::AbstractVariationalObjective, + restructure, + params_init ::AbstractVector{<:Real}, + max_iter ::Int, + objargs...; + kwargs... +) + optimize( + Random.default_rng(), + problem, + objective, + restructure, + params_init, + max_iter, + objargs...; + kwargs... + ) +end + +function optimize( + rng ::Random.AbstractRNG, problem, objective ::AbstractVariationalObjective, restructure, @@ -71,7 +94,6 @@ function optimize( objargs...; adbackend ::ADTypes.AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), - rng ::Random.AbstractRNG = Random.default_rng(), show_progress::Bool = true, state_init ::NamedTuple = NamedTuple(), callback! = nothing, @@ -120,15 +142,57 @@ function optimize( params, stats, state end -function optimize(problem, - objective ::AbstractVariationalObjective, +function optimize( + problem, + objective ::AbstractVariationalObjective, + restructure, + params_init ::AbstractVector{<:Real}, + max_iter ::Int, + objargs...; + kwargs... +) + optimize( + Random.default_rng(), + problem, + objective, + restructure, + params_init, + max_iter, + objargs...; + kwargs... + ) +end + +function optimize(rng ::Random.AbstractRNG, + problem, + objective ::AbstractVariationalObjective, variational_dist_init, - n_max_iter::Int, + n_max_iter ::Int, objargs...; kwargs...) λ, restructure = Optimisers.destructure(variational_dist_init) λ, logstats, state = optimize( - problem, objective, restructure, λ, n_max_iter, objargs...; kwargs... + rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs... ) restructure(λ), logstats, state end + + +function optimize( + problem, + objective ::AbstractVariationalObjective, + variational_dist_init, + max_iter ::Int, + objargs...; + kwargs... +) + optimize( + Random.default_rng(), + problem, + objective, + variational_dist_init, + max_iter, + objargs...; + kwargs... + ) +end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 033736df..dab8f560 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -40,10 +40,9 @@ using Test @testset "convergence" begin Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats, _ = optimize( - model, objective, q₀_z, T; + rng, model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, - rng = rng, adbackend = adbackend, ) @@ -59,10 +58,9 @@ using Test @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - model, objective, q₀_z, T; + rng, model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, - rng = rng, adbackend = adbackend, ) μ = mean(q.dist) @@ -70,10 +68,9 @@ using Test rng_repl = StableRNG(seed) q, stats, _ = optimize( - model, objective, q₀_z, T; + rng_repl, model, objective, q₀_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, - rng = rng_repl, adbackend = adbackend, ) μ_repl = mean(q.dist) diff --git a/test/optimize.jl b/test/optimize.jl index 21718f52..3de3cdc3 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -21,23 +21,38 @@ using Test rng = StableRNG(seed) q_ref, stats_ref, _ = optimize( - model, obj, q₀_z, T; + rng, model, obj, q₀_z, T; optimizer, show_progress = false, - rng, adbackend, ) λ_ref, _ = Optimisers.destructure(q_ref) + @testset "default_rng" begin + optimize( + model, obj, q₀_z, T; + optimizer, + show_progress = false, + adbackend, + ) + + λ₀, re = Optimisers.destructure(q₀_z) + optimize( + model, obj, re, λ₀, T; + optimizer, + show_progress = false, + adbackend, + ) + end + @testset "restructure" begin λ₀, re = Optimisers.destructure(q₀_z) rng = StableRNG(seed) λ, stats, _ = optimize( - model, obj, re, λ₀, T; + rng, model, obj, re, λ₀, T; optimizer, show_progress = false, - rng, adbackend, ) @test λ == λ_ref @@ -54,9 +69,8 @@ using Test rng = StableRNG(seed) _, stats, _ = optimize( - model, obj, q₀_z, T; + rng, model, obj, q₀_z, T; show_progress = false, - rng, adbackend, callback! ) @@ -70,19 +84,17 @@ using Test T_last = T - T_first q_first, _, state = optimize( - model, obj, q₀_z, T_first; + rng, model, obj, q₀_z, T_first; optimizer, show_progress = false, - rng, adbackend ) q, stats, _ = optimize( - model, obj, q_first, T_last; + rng, model, obj, q_first, T_last; optimizer, show_progress = false, state_init = state, - rng, adbackend ) @test q == q_ref diff --git a/test/runtests.jl b/test/runtests.jl index a4220f98..5f8fab41 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,5 +36,5 @@ include("models/normallognormal.jl") # Tests include("ad.jl") -include("advi_locscale.jl") include("optimize.jl") +include("advi_locscale.jl") From 92cf3547da9c58a6da3d97065dd6993669845f61 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 00:49:01 -0400 Subject: [PATCH 167/206] fix docs for optimize --- src/optimize.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 47fd102f..5beef712 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -9,9 +9,6 @@ objargs...; kwargs... ) - -Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` to the function `restructure`. - optimize( problem, objective ::AbstractVariationalObjective, @@ -21,7 +18,7 @@ Optimize the variational objective `objective` targeting the problem `problem` b kwargs... ) -Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the initial variational approximation `variational_dist_init` supports the `Optimisers.destructure` interface. +Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`. # Arguments - `objective`: Variational Objective. From d75fd3cf96c3e708757708c94674654da77a81d1 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 22:23:06 -0400 Subject: [PATCH 168/206] add compat bounds to test dependencies --- test/Project.toml | 21 ++++++++++++++++++++- test/runtests.jl | 1 - 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 0e81ec08..56aa3dff 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -21,3 +20,23 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +ADTypes = "0.2.1" +Bijectors = "0.13.6" +Distributions = "0.25.100" +DistributionsAD = "0.6.45" +Enzyme = "0.11.7" +FillArrays = "1.6.1" +ForwardDiff = "0.10.36" +Functors = "0.4.5" +LogDensityProblems = "2.1.1" +Optimisers = "0.3.0" +PDMats = "0.11.7" +Pkg = "1.9.2" +ReverseDiff = "1.15.1" +SimpleUnPack = "1.1.0" +StableRNGs = "1.0.0" +Tracker = "0.2.20" +Zygote = "0.6.63" +julia = "1.6" diff --git a/test/runtests.jl b/test/runtests.jl index 5f8fab41..9b48bc37 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,6 @@ using Test using Test: @testset, @test -using Comonicon using Random, StableRNGs using Statistics using Distributions From faa91ce33dbb48124dafb76bd46253fc72fe00b2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 22:25:37 -0400 Subject: [PATCH 169/206] update compat bound for `Optimisers` --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 075ae92f..658c168c 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ FillArrays = "1.3" ForwardDiff = "0.10.36" Functors = "0.4" LogDensityProblems = "2" -Optimisers = "0.2.16" +Optimisers = "0.2.16, 0.3" ProgressMeter = "1.6" Requires = "1.0" ReverseDiff = "1.15.1" diff --git a/test/Project.toml b/test/Project.toml index 56aa3dff..e61061db 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" LogDensityProblems = "2.1.1" -Optimisers = "0.3.0" +Optimisers = "0.2.16, 0.3" PDMats = "0.11.7" Pkg = "1.9.2" ReverseDiff = "1.15.1" From 6dc0bb745dea89cc40228e6a7d13a315dbbf2e33 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 7 Sep 2023 22:32:27 -0400 Subject: [PATCH 170/206] fix test compat --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index e61061db..89c4f77b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -33,7 +33,6 @@ Functors = "0.4.5" LogDensityProblems = "2.1.1" Optimisers = "0.2.16, 0.3" PDMats = "0.11.7" -Pkg = "1.9.2" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" From e941ad4b922e40c170fd6d6ed25c026cfc093cc7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:23:08 -0400 Subject: [PATCH 171/206] fix remove `!` in callback Co-authored-by: Tor Erlend Fjelde --- src/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index 5beef712..72db40e7 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -45,7 +45,7 @@ Optimize the variational objective `objective` targeting the problem `problem` b # Callback The callback function `callback!` has a signature of - cb(; stat, state, param, restructure, gradient) + callback!(; stat, state, param, restructure, gradient) The arguments are as follows: - `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. From 15e05534f102430e09b781e7fad916bd84254bd8 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:38:44 -0400 Subject: [PATCH 172/206] fix rng argument position in `advi` --- src/objectives/elbo/advi.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 37b53a47..a14fced5 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -72,34 +72,35 @@ end """ (advi::ADVI)( - prob, q; - rng::AbstractRNG = Random.default_rng(), - n_samples::Int = advi.n_samples + [rng], prob, q; n_samples::Int = advi.n_samples ) Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation using `n_samples` number of Monte Carlo samples. """ function (advi::ADVI)( + rng ::Random.AbstractRNG, prob, - q ::ContinuousMultivariateDistribution; - rng ::Random.AbstractRNG = Random.default_rng(), - n_samples::Int = advi.n_samples + q ::ContinuousDistribution; + n_samples::Int = advi.n_samples ) zs = rand(rng, q, n_samples) - advi(q, zs) + advi(prob, q, zs) end function (advi::ADVI)( + rng ::Random.AbstractRNG, prob, q_trans ::Bijectors.TransformedDistribution; - rng ::Random.AbstractRNG = Random.default_rng(), - n_samples::Int = advi.n_samples + n_samples::Int = advi.n_samples ) q = q_trans.dist ηs = rand(rng, q, n_samples) - advi(q_trans, ηs) + advi(prob, q_trans, ηs) end +(advi::ADVI)(prob, q::Distribution; n_samples::Int = advi.n_samples) = + advi(Random.default_rng(), prob, q; n_samples) + function estimate_gradient!( rng ::Random.AbstractRNG, prob, From a643cf28cedbb719c5f0276beecf9676723b45a5 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:39:30 -0400 Subject: [PATCH 173/206] fix callback signature in `optimize` --- src/optimize.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 72db40e7..17f31689 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -43,9 +43,9 @@ Optimize the variational objective `objective` targeting the problem `problem` b - `state`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. # Callback -The callback function `callback!` has a signature of +The callback function `callback` has a signature of - callback!(; stat, state, param, restructure, gradient) + callback(; stat, state, param, restructure, gradient) The arguments are as follows: - `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. @@ -93,7 +93,7 @@ function optimize( optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), show_progress::Bool = true, state_init ::NamedTuple = NamedTuple(), - callback! = nothing, + callback = nothing, prog = ProgressMeter.Progress( max_iter; desc = "Optimizing", @@ -120,8 +120,8 @@ function optimize( g = DiffResults.gradient(grad_buf) opt_st, λ = Optimisers.update!(opt_st, λ, g) - if !isnothing(callback!) - stat′ = callback!( + if !isnothing(callback) + stat′ = callback( ; stat, restructure, params=λ, gradient=g, state=(optimizer=opt_st, objective=obj_st) ) From ffa69a33adb08678c7618d8e25c0402d6b10a649 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:39:55 -0400 Subject: [PATCH 174/206] refactor reorganize test files and naming --- .../advi_distributionsad.jl} | 0 test/{ => interface}/ad.jl | 0 test/interface/advi.jl | 55 +++++++++++++++++++ test/{ => interface}/optimize.jl | 4 +- test/models/normal.jl | 43 +++++++++++++++ test/runtests.jl | 9 ++- 6 files changed, 106 insertions(+), 5 deletions(-) rename test/{advi_locscale.jl => inference/advi_distributionsad.jl} (100%) rename test/{ => interface}/ad.jl (100%) create mode 100644 test/interface/advi.jl rename test/{ => interface}/optimize.jl (97%) create mode 100644 test/models/normal.jl diff --git a/test/advi_locscale.jl b/test/inference/advi_distributionsad.jl similarity index 100% rename from test/advi_locscale.jl rename to test/inference/advi_distributionsad.jl diff --git a/test/ad.jl b/test/interface/ad.jl similarity index 100% rename from test/ad.jl rename to test/interface/ad.jl diff --git a/test/interface/advi.jl b/test/interface/advi.jl new file mode 100644 index 00000000..904305ac --- /dev/null +++ b/test/interface/advi.jl @@ -0,0 +1,55 @@ + +using Test + +@testset "advi" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + @testset "with bijector" begin + modelstats = normallognormal_meanfield(Float64; rng) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + b⁻¹ = Bijectors.bijector(model) |> inverse + q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q₀_z = Bijectors.transformed(q₀_η, b⁻¹) + obj = ADVI(10) + + rng = StableRNG(seed) + elbo_ref = obj(rng, model, q₀_z; n_samples=1024) + + @testset "determinism" begin + rng = StableRNG(seed) + elbo = obj(rng, model, q₀_z; n_samples=1024) + @test elbo == elbo_ref + end + + @testset "default_rng" begin + elbo = obj(model, q₀_z; n_samples=1024) + @test elbo ≈ elbo_ref rtol=0.1 + end + end + + @testset "without bijector" begin + modelstats = normal_meanfield(Float64; rng) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + q₀_z = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + + obj = ADVI(10) + rng = StableRNG(seed) + elbo_ref = obj(rng, model, q₀_z; n_samples=1024) + + @testset "determinism" begin + rng = StableRNG(seed) + elbo = obj(rng, model, q₀_z; n_samples=1024) + @test elbo == elbo_ref + end + + @testset "default_rng" begin + elbo = obj(model, q₀_z; n_samples=1024) + @test elbo ≈ elbo_ref rtol=0.1 + end + end +end diff --git a/test/optimize.jl b/test/interface/optimize.jl similarity index 97% rename from test/optimize.jl rename to test/interface/optimize.jl index 3de3cdc3..bbbb4998 100644 --- a/test/optimize.jl +++ b/test/interface/optimize.jl @@ -63,7 +63,7 @@ using Test rng = StableRNG(seed) test_values = rand(rng, T) - callback!(; stat, args...) = begin + callback(; stat, args...) = begin (test_value = test_values[stat.iteration],) end @@ -72,7 +72,7 @@ using Test rng, model, obj, q₀_z, T; show_progress = false, adbackend, - callback! + callback ) @test [stat.test_value for stat ∈ stats] == test_values end diff --git a/test/models/normal.jl b/test/models/normal.jl new file mode 100644 index 00000000..3efa7524 --- /dev/null +++ b/test/models/normal.jl @@ -0,0 +1,43 @@ + +struct TestNormal{M,S} + μ::M + Σ::S +end + +function LogDensityProblems.logdensity(model::TestNormal, θ) + @unpack μ, Σ = model + logpdf(MvNormal(μ, Σ), θ) +end + +function LogDensityProblems.dimension(model::TestNormal) + length(model.μ) +end + +function LogDensityProblems.capabilities(::Type{<:TestNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function normal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + L = tril(I + ones(realtype, n_dims, n_dims))/2 + Σ = L*L' |> Hermitian + + model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0))) + + TestModel(model, μ, L, n_dims, false) +end + +function normal_meanfield(realtype; rng = default_rng()) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = TestNormal(μ, Diagonal(σ.^2)) + + L = σ |> Diagonal + + TestModel(model, μ, L, n_dims, true) +end diff --git a/test/runtests.jl b/test/runtests.jl index 9b48bc37..6fda0be8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,8 +32,11 @@ struct TestModel{M,L,S} end include("models/normallognormal.jl") +include("models/normal.jl") # Tests -include("ad.jl") -include("optimize.jl") -include("advi_locscale.jl") +include("interface/ad.jl") +include("interface/optimize.jl") +include("interface/advi.jl") + +include("inference/advi_distributionsad.jl") From d5026e14aa441c32b12fa7050c583ec0011d0063 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:41:19 -0400 Subject: [PATCH 175/206] fix simplify description for `optimize` Co-authored-by: Tor Erlend Fjelde --- src/optimize.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index 17f31689..7e62e255 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -18,7 +18,9 @@ kwargs... ) -Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`. +Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients. + +The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`. # Arguments - `objective`: Variational Objective. From 764406b2a33687c83efb6969f9aeb16be735db81 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:51:02 -0400 Subject: [PATCH 176/206] fix remove redundant `Nothing` type signature for `maybe_init` --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ce11d0be..8dd7c37b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,7 +4,7 @@ function pm_next!(pm, stats::NamedTuple) end function maybe_init_optimizer( - state_init::Union{Nothing, NamedTuple}, + state_init::NamedTuple, optimizer ::Optimisers.AbstractRule, λ ::AbstractVector ) @@ -12,7 +12,7 @@ function maybe_init_optimizer( end function maybe_init_objective( - state_init::Union{Nothing, NamedTuple}, + state_init::NamedTuple, rng ::Random.AbstractRNG, objective ::AbstractVariationalObjective, λ ::AbstractVector, From 65006cb23cf258d0dd7acfa41d025a6c9aead4f4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:51:41 -0400 Subject: [PATCH 177/206] fix remove "internal use" warning in documentation --- src/AdvancedVI.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 35b493c8..203f5ae1 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -61,10 +61,6 @@ abstract type AbstractVariationalObjective end Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. This function needs to be implemented only if `obj` is stateful. - -!!! warning - This is an internal function. Thus, the signature is subject to change without - notice. """ init( ::Random.AbstractRNG, @@ -93,10 +89,6 @@ If the objective is stateful, `obj_state` is its previous state, otherwise, it i - `out`: The `MutableDiffResult` containing the objective value and gradient estimates. - `obj_state`: The updated state of the objective estimator. - `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`) - -!!! warning - This is an internal function. Thus, the signature is subject to change without - notice. """ function estimate_gradient! end From b23a610f1fdf3865827522d7c7b44a2badace2cd Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:51:59 -0400 Subject: [PATCH 178/206] refactor change `estimate_gradient!` signature to be type stable --- src/objectives/elbo/advi.jl | 12 ++++++------ src/optimize.jl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index a14fced5..7640f486 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -102,14 +102,14 @@ end advi(Random.default_rng(), prob, q; n_samples) function estimate_gradient!( - rng ::Random.AbstractRNG, + rng ::Random.AbstractRNG, + advi ::ADVI, + adbackend ::ADTypes.AbstractADType, + out ::DiffResults.MutableDiffResult, prob, - adbackend ::ADTypes.AbstractADType, - advi ::ADVI, - est_state, - λ ::Vector{<:Real}, + λ, restructure, - out ::DiffResults.MutableDiffResult + est_state, ) function f(λ′) q_trans = restructure(λ′) diff --git a/src/optimize.jl b/src/optimize.jl index 17f31689..cfe6179e 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -112,8 +112,8 @@ function optimize( stat = (iteration=t,) grad_buf, obj_st, stat′ = estimate_gradient!( - rng, problem, adbackend, objective, obj_st, - λ, restructure, grad_buf, objargs... + rng, objective, adbackend, grad_buf, problem, + λ, restructure, obj_st, objargs... ) stat = merge(stat, stat′) From 9c242a53dbd1f0b6f32be200695f94e6dd8b4201 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 00:57:10 -0400 Subject: [PATCH 179/206] add signature for computing `advi` over a fixed set of samples --- src/objectives/elbo/advi.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 7640f486..90e84b56 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -44,6 +44,13 @@ ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) = Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") +""" + (advi::ADVI)( + [rng], prob, q, zs::AbstractMatrix + ) + +Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation over the Monte Carlo samples `zs` (each column is a sample). +""" function (advi::ADVI)( prob, q ::Distributions.ContinuousMultivariateDistribution, From e0148637882b370717e128ce419508b1c1d88dd3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 01:10:44 -0400 Subject: [PATCH 180/206] fix change test tolerance --- test/interface/advi.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/interface/advi.jl b/test/interface/advi.jl index 904305ac..409a0ca6 100644 --- a/test/interface/advi.jl +++ b/test/interface/advi.jl @@ -16,16 +16,16 @@ using Test obj = ADVI(10) rng = StableRNG(seed) - elbo_ref = obj(rng, model, q₀_z; n_samples=1024) + elbo_ref = obj(rng, model, q₀_z; n_samples=10^4) @testset "determinism" begin rng = StableRNG(seed) - elbo = obj(rng, model, q₀_z; n_samples=1024) + elbo = obj(rng, model, q₀_z; n_samples=10^4) @test elbo == elbo_ref end @testset "default_rng" begin - elbo = obj(model, q₀_z; n_samples=1024) + elbo = obj(model, q₀_z; n_samples=10^4) @test elbo ≈ elbo_ref rtol=0.1 end end @@ -39,16 +39,16 @@ using Test obj = ADVI(10) rng = StableRNG(seed) - elbo_ref = obj(rng, model, q₀_z; n_samples=1024) + elbo_ref = obj(rng, model, q₀_z; n_samples=10^4) @testset "determinism" begin rng = StableRNG(seed) - elbo = obj(rng, model, q₀_z; n_samples=1024) + elbo = obj(rng, model, q₀_z; n_samples=10^4) @test elbo == elbo_ref end @testset "default_rng" begin - elbo = obj(model, q₀_z; n_samples=1024) + elbo = obj(model, q₀_z; n_samples=10^4) @test elbo ≈ elbo_ref rtol=0.1 end end From 71184fa4af2b46e02649fd939052fa2701c7c862 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 01:20:43 -0400 Subject: [PATCH 181/206] fix update documentation for `estimate_gradient!` --- src/AdvancedVI.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 203f5ae1..dd7f10ae 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -72,13 +72,13 @@ init( """ estimate_gradient!( rng ::Random.AbstractRNG, - prob, - adbackend ::ADTypes.AbstractADType, obj ::AbstractVariationalObjective, - obj_state, - λ ::AbstractVector, - restructure, + adbackend ::ADTypes.AbstractADType, out ::DiffResults.MutableDiffResult + prob, + λ, + restructure, + obj_state, ) Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`. From 9f6d6634e9a91c0129eb8642af27cd0c544b74c9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 01:20:55 -0400 Subject: [PATCH 182/206] refactor remove type constraint for variational parameters --- src/optimize.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index df09355d..0868044f 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -4,7 +4,7 @@ problem, objective ::AbstractVariationalObjective, restructure, - param_init ::AbstractVector{<:Real}, + param_init, max_iter ::Int, objargs...; kwargs... @@ -66,7 +66,7 @@ function optimize( problem, objective ::AbstractVariationalObjective, restructure, - params_init ::AbstractVector{<:Real}, + params_init, max_iter ::Int, objargs...; kwargs... @@ -88,7 +88,7 @@ function optimize( problem, objective ::AbstractVariationalObjective, restructure, - params_init ::AbstractVector{<:Real}, + params_init, max_iter ::Int, objargs...; adbackend ::ADTypes.AbstractADType, @@ -145,7 +145,7 @@ function optimize( problem, objective ::AbstractVariationalObjective, restructure, - params_init ::AbstractVector{<:Real}, + params_init, max_iter ::Int, objargs...; kwargs... From a673520f7510660e632ad95ffbe1a6f574c10bef Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 01:22:45 -0400 Subject: [PATCH 183/206] fix remove dead code --- src/optimize.jl | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 0868044f..208ffabe 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -62,27 +62,6 @@ Otherwise, just return `nothing`. """ -function optimize( - problem, - objective ::AbstractVariationalObjective, - restructure, - params_init, - max_iter ::Int, - objargs...; - kwargs... -) - optimize( - Random.default_rng(), - problem, - objective, - restructure, - params_init, - max_iter, - objargs...; - kwargs... - ) -end - function optimize( rng ::Random.AbstractRNG, problem, From a3f98867545c374c51b2d2956ded59384e05e57a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 01:27:48 -0400 Subject: [PATCH 184/206] add compat entry for stdlib --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 658c168c..70041561 100644 --- a/Project.toml +++ b/Project.toml @@ -45,9 +45,11 @@ Enzyme = "0.11.7" FillArrays = "1.3" ForwardDiff = "0.10.36" Functors = "0.4" +LinearAlgebra = "1" LogDensityProblems = "2" Optimisers = "0.2.16, 0.3" ProgressMeter = "1.6" +Random = "1" Requires = "1.0" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" From 7a92708950cf5a69957863b7925dbd3a49b294b3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 01:31:31 -0400 Subject: [PATCH 185/206] add compat entry for stdlib in `test/` --- test/Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 89c4f77b..eb4d9d59 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,12 +30,16 @@ Enzyme = "0.11.7" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" +LinearAlgebra = "1" LogDensityProblems = "2.1.1" Optimisers = "0.2.16, 0.3" PDMats = "0.11.7" +Random = "1" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" +Statistics = "1.9" +Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" julia = "1.6" From 5dd434d9dcffb9c7b96b2d5bfa8f564c42ba4e42 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 23:26:59 -0400 Subject: [PATCH 186/206] fix rng argument position in tests --- test/inference/advi_distributionsad.jl | 2 +- test/interface/advi.jl | 4 ++-- test/interface/optimize.jl | 2 +- test/models/normal.jl | 4 ++-- test/models/normallognormal.jl | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl index dab8f560..2d60a16a 100644 --- a/test/inference/advi_distributionsad.jl +++ b/test/inference/advi_distributionsad.jl @@ -24,7 +24,7 @@ using Test seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) - modelstats = modelconstr(realtype; rng) + modelstats = modelconstr(rng, realtype) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) diff --git a/test/interface/advi.jl b/test/interface/advi.jl index 409a0ca6..16db09ca 100644 --- a/test/interface/advi.jl +++ b/test/interface/advi.jl @@ -6,7 +6,7 @@ using Test rng = StableRNG(seed) @testset "with bijector" begin - modelstats = normallognormal_meanfield(Float64; rng) + modelstats = normallognormal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats @@ -31,7 +31,7 @@ using Test end @testset "without bijector" begin - modelstats = normal_meanfield(Float64; rng) + modelstats = normal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index bbbb4998..1384a4b4 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -6,7 +6,7 @@ using Test rng = StableRNG(seed) T = 1000 - modelstats = normallognormal_meanfield(Float64; rng) + modelstats = normallognormal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats diff --git a/test/models/normal.jl b/test/models/normal.jl index 3efa7524..3f305e1a 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -17,7 +17,7 @@ function LogDensityProblems.capabilities(::Type{<:TestNormal}) LogDensityProblems.LogDensityOrder{0}() end -function normal_fullrank(realtype; rng = default_rng()) +function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 μ = randn(rng, realtype, n_dims) @@ -29,7 +29,7 @@ function normal_fullrank(realtype; rng = default_rng()) TestModel(model, μ, L, n_dims, false) end -function normal_meanfield(realtype; rng = default_rng()) +function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 μ = randn(rng, realtype, n_dims) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index e2b9e816..c2cb2b0e 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -26,7 +26,7 @@ function Bijectors.bijector(model::NormalLogNormal) [1:1, 2:1+length(μ_y)]) end -function normallognormal_fullrank(realtype; rng = default_rng()) +function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 μ_x = randn(rng, realtype) @@ -43,12 +43,12 @@ function normallognormal_fullrank(realtype; rng = default_rng()) Σ = Σ |> Hermitian μ = vcat(μ_x, μ_y) - L = cholesky(Σ).L |> LowerTriangular + L = cholesky(Σ).L TestModel(model, μ, L, n_dims+1, false) end -function normallognormal_meanfield(realtype; rng = default_rng()) +function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 μ_x = randn(rng, realtype) From a764d9baf6826e519faeb5762f2b7195437d16ec Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 23:30:59 -0400 Subject: [PATCH 187/206] refactor change name of inference test --- test/inference/advi_distributionsad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl index 2d60a16a..01c7a96e 100644 --- a/test/inference/advi_distributionsad.jl +++ b/test/inference/advi_distributionsad.jl @@ -3,8 +3,8 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using Test -@testset "advi" begin - @testset "locscale" begin +@testset "inference_advi" begin + @testset "distributionsad" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype ∈ [Float64], # Currently only tested against Float64 (modelname, modelconstr) ∈ Dict( From 8af8a5f6ede44600156398546f58097b10b8bdd5 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 23 Oct 2023 23:48:20 -0400 Subject: [PATCH 188/206] fix documentation for `optimize` --- src/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index 208ffabe..1a39cb41 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -35,7 +35,7 @@ The variational approximation can be constructed by passing the variational para - `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.) - `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.) - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) -- `callback!`: Callback function called after every iteration. See further information below. (Default: `nothing`.) +- `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.) - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) - `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) From 5f1fb52b5be0c46ea295087f9f3644396f239c66 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Oct 2023 00:06:50 -0400 Subject: [PATCH 189/206] refactor rewrite the documentation for the global interfaces --- src/AdvancedVI.jl | 73 ++++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index dd7f10ae..54c2b1eb 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -27,17 +27,15 @@ using StatsBase # derivatives """ - value_and_gradient!( - ad::ADTypes.AbstractADType, - f, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult - ) - -Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`. -The result is stored in `out`. -The function `f` must return a scalar value. -The gradient is stored in `out` as a vector of the same length as `θ`. + value_and_gradient!(ad, f, θ, out) + +Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`. + +# Arguments +- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `f`: Function subject to differentiation. +- `θ`: The point to evaluate the gradient. +- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ function value_and_gradient! end @@ -45,22 +43,26 @@ function value_and_gradient! end """ AbstractVariationalObjective -An VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`. -Furthermore, it should implement the functions `estimate_gradient`. +Abstract type for the VI algorithms supported by `AdvancedVI`. + +# Implementations +To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective`. +Also, it should provide gradients by implementing the function `estimate_gradient!`. If the estimator is stateful, it can implement `init` to initialize the state. """ abstract type AbstractVariationalObjective end """ - init( - rng::Random.AbstractRNG, - obj::AbstractVariationalObjective, - λ::AbstractVector, - restructure - ) + init(rng, obj, λ, restructure) Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. This function needs to be implemented only if `obj` is stateful. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `λ`: Initial variational parameters. +- `restructure`: Function that reconstructs the variational approximation from `λ`. """ init( ::Random.AbstractRNG, @@ -70,25 +72,24 @@ init( ) = nothing """ - estimate_gradient!( - rng ::Random.AbstractRNG, - obj ::AbstractVariationalObjective, - adbackend ::ADTypes.AbstractADType, - out ::DiffResults.MutableDiffResult - prob, - λ, - restructure, - obj_state, - ) - -Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`. -The estimated objective value and gradient are then stored in `out`. -If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`. + estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state) + +Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend. +- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `λ`: Variational parameters to evaluate the gradient on. +- `restructure`: Function that reconstructs the variational approximation from `λ`. +- `obj_state`: Previous state of the objective. # Returns -- `out`: The `MutableDiffResult` containing the objective value and gradient estimates. -- `obj_state`: The updated state of the objective estimator. -- `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`) +- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `obj_state`: The updated state of the objective. +- `stat::NamedTuple`: Statistics and logs generated during estimation. """ function estimate_gradient! end From 2491c64ac825e5b3b59309b0eacb63363dab00ce Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Oct 2023 00:11:54 -0400 Subject: [PATCH 190/206] fix compat error --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index eb4d9d59..490782cb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -38,7 +38,7 @@ Random = "1" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" -Statistics = "1.9" +Statistics = "1" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" From 92d148988e1fcd9ca9e67ff084fa11d5c5960e69 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Oct 2023 00:12:12 -0400 Subject: [PATCH 191/206] fix documentation for `optimize` to be single line --- src/optimize.jl | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 1a39cb41..7e0032dc 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -1,33 +1,18 @@ """ - optimize( - problem, - objective ::AbstractVariationalObjective, - restructure, - param_init, - max_iter ::Int, - objargs...; - kwargs... - ) - optimize( - problem, - objective ::AbstractVariationalObjective, - variational_dist_init, - max_iter ::Int, - objargs...; - kwargs... - ) + optimize(problem, objective, restructure, param_init, max_iter, objargs...; kwargs...) + optimize(problem, objective, variational_dist_init, max_iter, objargs...; kwargs...) Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients. The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`. # Arguments -- `objective`: Variational Objective. +- `objective::AbstractVariationalObjective`: Variational Objective. - `param_init`: Initial value of the variational parameters. - `restruct`: Function that reconstructs the variational approximation from the flattened parameters. - `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`. -- `max_iter`: Maximum number of iterations. +- `max_iter::Int`: Maximum number of iterations. - `objargs...`: Arguments to be passed to `objective`. # Keyword Arguments From a03e955245c20800580e541d0402a70d4588235a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Oct 2023 00:15:41 -0400 Subject: [PATCH 192/206] refactor remove begin end for one-liner --- test/interface/optimize.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index 1384a4b4..3459b4c3 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -63,9 +63,7 @@ using Test rng = StableRNG(seed) test_values = rand(rng, T) - callback(; stat, args...) = begin - (test_value = test_values[stat.iteration],) - end + callback(; stat, args...) = (test_value = test_values[stat.iteration],) rng = StableRNG(seed) _, stats, _ = optimize( From ff83c036a3c2ee5f8d1d33f62d1894592a00497b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 10 Nov 2023 02:38:04 -0500 Subject: [PATCH 193/206] refactor create unified interface for estimating objectives --- src/AdvancedVI.jl | 24 +++++++++++++++++++++- src/objectives/elbo/advi.jl | 40 ++++++++++++++++++++----------------- test/interface/advi.jl | 12 +++++------ 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 54c2b1eb..b1decc4a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -46,7 +46,7 @@ function value_and_gradient! end Abstract type for the VI algorithms supported by `AdvancedVI`. # Implementations -To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective`. +To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. Also, it should provide gradients by implementing the function `estimate_gradient!`. If the estimator is stateful, it can implement `init` to initialize the state. """ @@ -71,6 +71,28 @@ init( ::Any ) = nothing +""" + estimate_objective([rng,] obj, q, prob, kwargs...) + +Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `q`: Variational approximation. + +# Keyword Arguments +For the keywword arguments, refer to the respective documentation for each variational objective. + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_objective end + +export estimate_objective + + """ estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 90e84b56..ef339cfd 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -51,19 +51,21 @@ Base.show(io::IO, advi::ADVI) = Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation over the Monte Carlo samples `zs` (each column is a sample). """ -function (advi::ADVI)( +function estimate_objective_with_samples( + advi::ADVI, + q ::Distributions.ContinuousMultivariateDistribution, prob, - q ::Distributions.ContinuousMultivariateDistribution, - zs::AbstractMatrix + zs ::AbstractMatrix ) 𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs)) ℍ = advi.entropy(q, zs) 𝔼ℓ + ℍ end -function (advi::ADVI)( - prob, +function estimate_objective_with_samples( + advi ::ADVI, q_trans::Bijectors.TransformedDistribution, + prob, ηs ::AbstractMatrix ) @unpack dist, transform = q_trans @@ -78,35 +80,37 @@ function (advi::ADVI)( end """ - (advi::ADVI)( - [rng], prob, q; n_samples::Int = advi.n_samples + estimate_objective( + advi::ADVI, [rng], prob, q; n_samples::Int = advi.n_samples ) Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation using `n_samples` number of Monte Carlo samples. """ -function (advi::ADVI)( +function estimate_objective( rng ::Random.AbstractRNG, - prob, - q ::ContinuousDistribution; + advi ::ADVI, + q ::ContinuousDistribution, + prob; n_samples::Int = advi.n_samples ) zs = rand(rng, q, n_samples) - advi(prob, q, zs) + estimate_objective_with_samples(advi, q, prob, zs) end -function (advi::ADVI)( +function estimate_objective( rng ::Random.AbstractRNG, - prob, - q_trans ::Bijectors.TransformedDistribution; + advi ::ADVI, + q_trans ::Bijectors.TransformedDistribution, + prob; n_samples::Int = advi.n_samples ) q = q_trans.dist ηs = rand(rng, q, n_samples) - advi(prob, q_trans, ηs) + estimate_objective_with_samples(advi, q_trans, prob, ηs) end -(advi::ADVI)(prob, q::Distribution; n_samples::Int = advi.n_samples) = - advi(Random.default_rng(), prob, q; n_samples) +estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) = + estimate_objective(Random.default_rng(), advi, q, prob; n_samples) function estimate_gradient!( rng ::Random.AbstractRNG, @@ -122,7 +126,7 @@ function estimate_gradient!( q_trans = restructure(λ′) q = q_trans.dist ηs = rand(rng, q, advi.n_samples) - -advi(prob, q_trans, ηs) + -estimate_objective_with_samples(advi, q_trans, prob, ηs) end value_and_gradient!(adbackend, f, λ, out) diff --git a/test/interface/advi.jl b/test/interface/advi.jl index 16db09ca..1df396e4 100644 --- a/test/interface/advi.jl +++ b/test/interface/advi.jl @@ -16,16 +16,16 @@ using Test obj = ADVI(10) rng = StableRNG(seed) - elbo_ref = obj(rng, model, q₀_z; n_samples=10^4) + elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) @testset "determinism" begin rng = StableRNG(seed) - elbo = obj(rng, model, q₀_z; n_samples=10^4) + elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) @test elbo == elbo_ref end @testset "default_rng" begin - elbo = obj(model, q₀_z; n_samples=10^4) + elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4) @test elbo ≈ elbo_ref rtol=0.1 end end @@ -39,16 +39,16 @@ using Test obj = ADVI(10) rng = StableRNG(seed) - elbo_ref = obj(rng, model, q₀_z; n_samples=10^4) + elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) @testset "determinism" begin rng = StableRNG(seed) - elbo = obj(rng, model, q₀_z; n_samples=10^4) + elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) @test elbo == elbo_ref end @testset "default_rng" begin - elbo = obj(model, q₀_z; n_samples=10^4) + elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4) @test elbo ≈ elbo_ref rtol=0.1 end end From aecc655dcfb87b6f09d0278e405935afe162db60 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 10 Nov 2023 03:01:45 -0500 Subject: [PATCH 194/206] refactor unify interface for entropy estimator, fix advi docs --- src/objectives/elbo/advi.jl | 62 ++++++++++++++++++++++------------ src/objectives/elbo/entropy.jl | 43 +++++++++++++++++++---- 2 files changed, 76 insertions(+), 29 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index ef339cfd..54024db8 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -45,20 +45,28 @@ Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") """ - (advi::ADVI)( - [rng], prob, q, zs::AbstractMatrix - ) + estimate_objective_with_samples(obj, prob, q, zs) + +Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo samples. + +# Arguments +- `advi::ADVI`: ADVI objective. +- `q`: Variational approximation +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `mc_samples::AbstractMatrix`: Samples to be used to estimate the energy. (Each column is a single sample.) + +# Returns +- `obj_est`: Estimate of the objective value. -Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation over the Monte Carlo samples `zs` (each column is a sample). """ function estimate_objective_with_samples( - advi::ADVI, - q ::Distributions.ContinuousMultivariateDistribution, + advi ::ADVI, + q ::Distributions.ContinuousMultivariateDistribution, prob, - zs ::AbstractMatrix + mc_samples::AbstractMatrix ) - 𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs)) - ℍ = advi.entropy(q, zs) + 𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples)) + ℍ = estimate_entropy(advi.entropy, mc_samples, q) 𝔼ℓ + ℍ end @@ -66,25 +74,34 @@ function estimate_objective_with_samples( advi ::ADVI, q_trans::Bijectors.TransformedDistribution, prob, - ηs ::AbstractMatrix + mc_samples_unconstr::AbstractMatrix ) @unpack dist, transform = q_trans q = dist b⁻¹ = transform - 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, ηᵢ) - LogDensityProblems.logdensity(prob, zᵢ) + logdetjacᵢ + 𝔼ℓ = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr + mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, mc_sample_unconstr) + LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ end - ℍ = advi.entropy(q, ηs) + ℍ = estimate_entropy(advi.entropy, mc_samples_unconstr, q) 𝔼ℓ + ℍ end """ - estimate_objective( - advi::ADVI, [rng], prob, q; n_samples::Int = advi.n_samples - ) + estimate_objective([rng,] advi, q, prob; n_samples) + +Estimate the ELBO using the ADVI formulation. + +# Arguments +- `advi::ADVI`: ADVI objective. +- `q`: Variational approximation +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. + +# Keyword Arguments +- `n_samples::Int = advi.n_samples`: Number of samples to be used to estimate the objective. -Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation using `n_samples` number of Monte Carlo samples. +# Returns +- `obj_est`: Estimate of the objective value. """ function estimate_objective( rng ::Random.AbstractRNG, @@ -93,8 +110,8 @@ function estimate_objective( prob; n_samples::Int = advi.n_samples ) - zs = rand(rng, q, n_samples) - estimate_objective_with_samples(advi, q, prob, zs) + mc_samples = rand(rng, q, n_samples) + estimate_objective_with_samples(advi, q, prob, mc_samples) end function estimate_objective( @@ -105,8 +122,8 @@ function estimate_objective( n_samples::Int = advi.n_samples ) q = q_trans.dist - ηs = rand(rng, q, n_samples) - estimate_objective_with_samples(advi, q_trans, prob, ηs) + mc_unconstr_samples = rand(rng, q, n_samples) + estimate_objective_with_samples(advi, q_trans, prob, mc_unconstr_samples) end estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) = @@ -122,6 +139,7 @@ function estimate_gradient!( restructure, est_state, ) + q_trans_stop = restructure(λ) function f(λ′) q_trans = restructure(λ′) q = q_trans.dist diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 63854ec0..48dad275 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -1,15 +1,44 @@ +""" + estimate_entropy(entropy_estimator, mc_samples, q) + +Estimate the entropy of `q`. + +# Arguments +- `entropy_estimator`: Entropy estimation strategy. +- `q`: Variational approximation. +- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) + +# Returns +- `obj_est`: Estimate of the objective value. +""" + +function estimate_entropy end + + +""" + ClosedFormEntropy() + +Use closed-form expression of entropy. + +# Requirements +- `q` implements `entropy`. + +# References +* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. +* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. +""" struct ClosedFormEntropy <: AbstractEntropyEstimator end -function (::ClosedFormEntropy)(q, ::AbstractMatrix) +function estimate_entropy(::ClosedFormEntropy, ::Any, q) entropy(q) end struct MonteCarloEntropy <: AbstractEntropyEstimator end -function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) - mean(eachcol(ηs)) do ηᵢ - -logpdf(q, ηᵢ) +function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q) + mean(eachcol(mc_samples)) do mc_sample + -logpdf(q, mc_sample) end end @@ -27,8 +56,8 @@ The "sticking the landing" entropy estimator. """ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end -function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix) - ChainRulesCore.@ignore_derivatives mean(eachcol(ηs)) do ηᵢ - -logpdf(q, ηᵢ) +function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q) + ChainRulesCore.@ignore_derivatives mean(eachcol(mc_samples)) do mc_sample + -logpdf(q, mc_sample) end end From a8d532ae33de54c295f611ae4c00bf53ab764e07 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 10 Nov 2023 03:25:00 -0500 Subject: [PATCH 195/206] fix STL estimator to use manually stopped gradients instead --- src/objectives/elbo/advi.jl | 73 +++++++++++++++++++++------------- src/objectives/elbo/entropy.jl | 13 +++--- 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 54024db8..8b7a2771 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -51,7 +51,7 @@ Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo sam # Arguments - `advi::ADVI`: ADVI objective. -- `q`: Variational approximation +- `q`: Variational approximation. - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - `mc_samples::AbstractMatrix`: Samples to be used to estimate the energy. (Each column is a single sample.) @@ -59,34 +59,64 @@ Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo sam - `obj_est`: Estimate of the objective value. """ +function estimate_objective_with_samples( + advi ::ADVI, + q ::Union{Distributions.ContinuousMultivariateDistribution, + Bijectors.TransformedDistribution}, + prob, + mc_samples::AbstractMatrix +) + estimate_objective_with_samples(advi, q, q, prob, mc_samples) +end + + function estimate_objective_with_samples( advi ::ADVI, q ::Distributions.ContinuousMultivariateDistribution, + q_stop ::Distributions.ContinuousMultivariateDistribution, prob, mc_samples::AbstractMatrix ) 𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples)) - ℍ = estimate_entropy(advi.entropy, mc_samples, q) + ℍ = estimate_entropy(advi.entropy, mc_samples, q, q_stop) 𝔼ℓ + ℍ end function estimate_objective_with_samples( - advi ::ADVI, - q_trans::Bijectors.TransformedDistribution, + advi ::ADVI, + q_trans ::Bijectors.TransformedDistribution, + q_trans_stop::Bijectors.TransformedDistribution, prob, mc_samples_unconstr::AbstractMatrix ) @unpack dist, transform = q_trans - q = dist - b⁻¹ = transform - 𝔼ℓ = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr + q = dist + q_stop = q_trans_stop.dist + b⁻¹ = transform + 𝔼ℓ = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, mc_sample_unconstr) LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ end - ℍ = estimate_entropy(advi.entropy, mc_samples_unconstr, q) + ℍ = estimate_entropy(advi.entropy, mc_samples_unconstr, q, q_stop) 𝔼ℓ + ℍ end +function rand_uncontrained_samples( + rng ::Random.AbstractRNG, + q ::ContinuousDistribution, + n_samples::Int, +) + rand(rng, q, n_samples) +end + +function rand_uncontrained_samples( + rng ::Random.AbstractRNG, + q_trans ::Bijectors.TransformedDistribution, + n_samples::Int, +) + rand(rng, q_trans.dist, n_samples) +end + """ estimate_objective([rng,] advi, q, prob; n_samples) @@ -106,24 +136,12 @@ Estimate the ELBO using the ADVI formulation. function estimate_objective( rng ::Random.AbstractRNG, advi ::ADVI, - q ::ContinuousDistribution, + q, prob; n_samples::Int = advi.n_samples ) - mc_samples = rand(rng, q, n_samples) - estimate_objective_with_samples(advi, q, prob, mc_samples) -end - -function estimate_objective( - rng ::Random.AbstractRNG, - advi ::ADVI, - q_trans ::Bijectors.TransformedDistribution, - prob; - n_samples::Int = advi.n_samples -) - q = q_trans.dist - mc_unconstr_samples = rand(rng, q, n_samples) - estimate_objective_with_samples(advi, q_trans, prob, mc_unconstr_samples) + mc_samples_unconstr = rand_uncontrained_samples(rng, q, n_samples) + estimate_objective_with_samples(advi, q, prob, mc_samples_unconstr) end estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) = @@ -139,12 +157,11 @@ function estimate_gradient!( restructure, est_state, ) - q_trans_stop = restructure(λ) + q_stop = restructure(λ) function f(λ′) - q_trans = restructure(λ′) - q = q_trans.dist - ηs = rand(rng, q, advi.n_samples) - -estimate_objective_with_samples(advi, q_trans, prob, ηs) + q = restructure(λ′) + mc_samples_unconstr = rand_uncontrained_samples(rng, q, advi.n_samples) + -estimate_objective_with_samples(advi, q, q_stop, prob, mc_samples_unconstr) end value_and_gradient!(adbackend, f, λ, out) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 48dad275..461aa030 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -1,12 +1,13 @@ """ - estimate_entropy(entropy_estimator, mc_samples, q) + estimate_entropy(entropy_estimator, mc_samples, q, q_stop) Estimate the entropy of `q`. # Arguments - `entropy_estimator`: Entropy estimation strategy. - `q`: Variational approximation. +- `q_stop`: Variational approximation with "stopped gradients". - `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) # Returns @@ -30,13 +31,13 @@ Use closed-form expression of entropy. """ struct ClosedFormEntropy <: AbstractEntropyEstimator end -function estimate_entropy(::ClosedFormEntropy, ::Any, q) +function estimate_entropy(::ClosedFormEntropy, ::Any, q, ::Any) entropy(q) end struct MonteCarloEntropy <: AbstractEntropyEstimator end -function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q) +function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, ::Any) mean(eachcol(mc_samples)) do mc_sample -logpdf(q, mc_sample) end @@ -56,8 +57,8 @@ The "sticking the landing" entropy estimator. """ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end -function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q) - ChainRulesCore.@ignore_derivatives mean(eachcol(mc_samples)) do mc_sample - -logpdf(q, mc_sample) +function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, ::Any, q_stop) + mean(eachcol(mc_samples)) do mc_sample + -logpdf(q_stop, mc_sample) end end From 65e9b126a6c0038bac9f6904e8d851b211b528b3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 10 Nov 2023 03:25:15 -0500 Subject: [PATCH 196/206] add inference test for a non-bijector model --- test/inference/advi_distributionsad.jl | 81 +++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 3 deletions(-) diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl index 01c7a96e..9919ce2b 100644 --- a/test/inference/advi_distributionsad.jl +++ b/test/inference/advi_distributionsad.jl @@ -4,6 +4,81 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using Test @testset "inference_advi" begin + @testset "distributionsad" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64], # Currently only tested against Float64 + (modelname, modelconstr) ∈ Dict( + :Normal=> normal_meanfield, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => ADVI(10), + :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + μ₀ = Zeros(realtype, n_dims) + L₀ = Diagonal(Ones(realtype, n_dims)) + q₀_z = TuringDiagMvNormal(μ₀, diag(L₀)) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = mean(q) + L = sqrt(cov(q)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = mean(q) + L = sqrt(cov(q)) + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q) + L_repl = sqrt(cov(q)) + @test μ == μ_repl + @test L == L_repl + end + end + end +end + +@testset "inference_bijectors_advi" begin @testset "distributionsad" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype ∈ [Float64], # Currently only tested against Float64 @@ -16,9 +91,9 @@ using Test ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), - # :ReverseDiff => AutoReverseDiff(), - # :Zygote => AutoZygote(), - # :Enzyme => AutoEnzyme(), + #:ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d) From 3691f160a7923d7a569cb93eb51719e8f4065beb Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 11 Nov 2023 00:24:08 -0500 Subject: [PATCH 197/206] refactor add indirections to handle STL and bijectors in ADVI --- src/objectives/elbo/advi.jl | 106 ++++++++++++++++----------------- src/objectives/elbo/entropy.jl | 24 ++++---- 2 files changed, 60 insertions(+), 70 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 8b7a2771..98f8ae99 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -39,82 +39,76 @@ struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObject n_samples::Int end -ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) = ADVI(entropy, n_samples) +ADVI( + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy() +) = ADVI(entropy, n_samples) Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") -""" - estimate_objective_with_samples(obj, prob, q, zs) +maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop -Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo samples. +maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q -# Arguments -- `advi::ADVI`: ADVI objective. -- `q`: Variational approximation. -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `mc_samples::AbstractMatrix`: Samples to be used to estimate the energy. (Each column is a single sample.) - -# Returns -- `obj_est`: Estimate of the objective value. +function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, mc_samples, q, q_stop) + q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) + estimate_entropy(entropy_estimator, mc_samples, q_maybe_stop) +end -""" -function estimate_objective_with_samples( - advi ::ADVI, - q ::Union{Distributions.ContinuousMultivariateDistribution, - Bijectors.TransformedDistribution}, - prob, - mc_samples::AbstractMatrix -) - estimate_objective_with_samples(advi, q, q, prob, mc_samples) +function estimate_energy_with_samples(::ADVI, mc_samples::AbstractMatrix, prob) + mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples)) end +function estimate_energy_with_samples_bijector(::ADVI, mc_samples::AbstractMatrix, invbij, prob) + mean(eachcol(mc_samples)) do mc_sample + mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(invbij, mc_sample) + LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ + end +end -function estimate_objective_with_samples( +function estimate_advi_maybe_stl_with_samples( advi ::ADVI, - q ::Distributions.ContinuousMultivariateDistribution, - q_stop ::Distributions.ContinuousMultivariateDistribution, - prob, - mc_samples::AbstractMatrix + q ::ContinuousDistribution, + q_stop ::ContinuousDistribution, + mc_samples::AbstractMatrix, + prob ) - 𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples)) - ℍ = estimate_entropy(advi.entropy, mc_samples, q, q_stop) - 𝔼ℓ + ℍ + energy = estimate_energy_with_samples(advi, mc_samples, prob) + entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop) + energy + entropy end -function estimate_objective_with_samples( +function estimate_advi_maybe_stl_with_samples( advi ::ADVI, q_trans ::Bijectors.TransformedDistribution, q_trans_stop::Bijectors.TransformedDistribution, - prob, - mc_samples_unconstr::AbstractMatrix + mc_samples ::AbstractMatrix, + prob ) - @unpack dist, transform = q_trans - q = dist - q_stop = q_trans_stop.dist - b⁻¹ = transform - 𝔼ℓ = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr - mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, mc_sample_unconstr) - LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ - end - ℍ = estimate_entropy(advi.entropy, mc_samples_unconstr, q, q_stop) - 𝔼ℓ + ℍ + q = q_trans.dist + invbij = q_trans.transform + q_stop = q_trans_stop.dist + energy = estimate_energy_with_samples_bijector(advi, mc_samples, invbij, prob) + entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop) + energy + entropy end -function rand_uncontrained_samples( +rand_unconstrained( rng ::Random.AbstractRNG, q ::ContinuousDistribution, - n_samples::Int, -) - rand(rng, q, n_samples) -end + n_samples::Int +) = rand(rng, q, n_samples) -function rand_uncontrained_samples( +rand_unconstrained( rng ::Random.AbstractRNG, - q_trans ::Bijectors.TransformedDistribution, - n_samples::Int, -) - rand(rng, q_trans.dist, n_samples) + q ::Bijectors.TransformedDistribution, + n_samples::Int +) = rand(rng, q.dist, n_samples) + +function estimate_advi_maybe_stl(rng::Random.AbstractRNG, advi::ADVI, q, q_stop, prob) + mc_samples = rand_unconstrained(rng, q, advi.n_samples) + estimate_advi_maybe_stl_with_samples(advi, q, q_stop, mc_samples, prob) end """ @@ -140,8 +134,8 @@ function estimate_objective( prob; n_samples::Int = advi.n_samples ) - mc_samples_unconstr = rand_uncontrained_samples(rng, q, n_samples) - estimate_objective_with_samples(advi, q, prob, mc_samples_unconstr) + mc_samples = rand_unconstrained(rng, q, n_samples) + estimate_advi_maybe_stl_with_samples(advi, q, q, mc_samples, prob) end estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) = @@ -160,8 +154,8 @@ function estimate_gradient!( q_stop = restructure(λ) function f(λ′) q = restructure(λ′) - mc_samples_unconstr = rand_uncontrained_samples(rng, q, advi.n_samples) - -estimate_objective_with_samples(advi, q, q_stop, prob, mc_samples_unconstr) + elbo = estimate_advi_maybe_stl(rng, advi, q, q_stop, prob) + -elbo end value_and_gradient!(adbackend, f, λ, out) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 461aa030..6fa3095e 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -1,13 +1,12 @@ """ - estimate_entropy(entropy_estimator, mc_samples, q, q_stop) + estimate_entropy(entropy_estimator, mc_samples, q) Estimate the entropy of `q`. # Arguments - `entropy_estimator`: Entropy estimation strategy. - `q`: Variational approximation. -- `q_stop`: Variational approximation with "stopped gradients". - `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) # Returns @@ -16,7 +15,6 @@ Estimate the entropy of `q`. function estimate_entropy end - """ ClosedFormEntropy() @@ -31,18 +29,10 @@ Use closed-form expression of entropy. """ struct ClosedFormEntropy <: AbstractEntropyEstimator end -function estimate_entropy(::ClosedFormEntropy, ::Any, q, ::Any) +function estimate_entropy(::ClosedFormEntropy, ::Any, q) entropy(q) end -struct MonteCarloEntropy <: AbstractEntropyEstimator end - -function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, ::Any) - mean(eachcol(mc_samples)) do mc_sample - -logpdf(q, mc_sample) - end -end - """ StickingTheLandingEntropy() @@ -57,8 +47,14 @@ The "sticking the landing" entropy estimator. """ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end -function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, ::Any, q_stop) +struct MonteCarloEntropy <: AbstractEntropyEstimator end + +function estimate_entropy( + ::Union{MonteCarloEntropy, StickingTheLandingEntropy}, + mc_samples::AbstractMatrix, + q +) mean(eachcol(mc_samples)) do mc_sample - -logpdf(q_stop, mc_sample) + -logpdf(q, mc_sample) end end From a063583c8e5b9cb83efd006110f662591e68c0ba Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 11 Nov 2023 00:44:07 -0500 Subject: [PATCH 198/206] refactor split inference tests for advi+distributionsad --- test/inference/advi_distributionsad.jl | 208 ++++++------------ .../advi_distributionsad_bijectors.jl | 81 +++++++ test/runtests.jl | 1 + 3 files changed, 146 insertions(+), 144 deletions(-) create mode 100644 test/inference/advi_distributionsad_bijectors.jl diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl index 9919ce2b..e82a9ec0 100644 --- a/test/inference/advi_distributionsad.jl +++ b/test/inference/advi_distributionsad.jl @@ -3,156 +3,76 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using Test -@testset "inference_advi" begin - @testset "distributionsad" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype ∈ [Float64], # Currently only tested against Float64 - (modelname, modelconstr) ∈ Dict( - :Normal=> normal_meanfield, - ), - (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => ADVI(10), - :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), - ), - (adbackname, adbackend) ∈ Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), - #:Enzyme => AutoEnzyme(), +@testset "inference_advi_distributionsad" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :Normal=> normal_meanfield, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => ADVI(10), + :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + μ₀ = Zeros(realtype, n_dims) + L₀ = Diagonal(Ones(realtype, n_dims)) + q₀_z = TuringDiagMvNormal(μ₀, diag(L₀)) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, ) - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + μ = mean(q) + L = sqrt(cov(q)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) - - μ₀ = Zeros(realtype, n_dims) - L₀ = Diagonal(Ones(realtype, n_dims)) - q₀_z = TuringDiagMvNormal(μ₀, diag(L₀)) - - @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats, _ = optimize( - rng, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - - μ = mean(q) - L = sqrt(cov(q)) - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ Δλ₀/T^(1/4) - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = StableRNG(seed) - q, stats, _ = optimize( - rng, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - μ = mean(q) - L = sqrt(cov(q)) - - rng_repl = StableRNG(seed) - q, stats, _ = optimize( - rng_repl, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - μ_repl = mean(q) - L_repl = sqrt(cov(q)) - @test μ == μ_repl - @test L == L_repl - end + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) end - end -end -@testset "inference_bijectors_advi" begin - @testset "distributionsad" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype ∈ [Float64], # Currently only tested against Float64 - (modelname, modelconstr) ∈ Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - ), - (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => ADVI(10), - :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), - ), - (adbackname, adbackend) ∈ Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), - #:Enzyme => AutoEnzyme(), + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, ) - - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) - - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) - μ₀ = Zeros(realtype, n_dims) - L₀ = Diagonal(Ones(realtype, n_dims)) - - q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) - q₀_z = Bijectors.transformed(q₀_η, b⁻¹) - - @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats, _ = optimize( - rng, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - - μ = mean(q.dist) - L = sqrt(cov(q.dist)) - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ Δλ₀/T^(1/4) - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = StableRNG(seed) - q, stats, _ = optimize( - rng, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - μ = mean(q.dist) - L = sqrt(cov(q.dist)) - - rng_repl = StableRNG(seed) - q, stats, _ = optimize( - rng_repl, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - μ_repl = mean(q.dist) - L_repl = sqrt(cov(q.dist)) - @test μ == μ_repl - @test L == L_repl - end + μ = mean(q) + L = sqrt(cov(q)) + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q) + L_repl = sqrt(cov(q)) + @test μ == μ_repl + @test L == L_repl end end end + diff --git a/test/inference/advi_distributionsad_bijectors.jl b/test/inference/advi_distributionsad_bijectors.jl new file mode 100644 index 00000000..29602fe7 --- /dev/null +++ b/test/inference/advi_distributionsad_bijectors.jl @@ -0,0 +1,81 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference_advi_distributionsad_bijectors" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => ADVI(10), + :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ₀ = Zeros(realtype, n_dims) + L₀ = Diagonal(Ones(realtype, n_dims)) + + q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) + q₀_z = Bijectors.transformed(q₀_η, b⁻¹) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = mean(q.dist) + L = sqrt(cov(q.dist)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = mean(q.dist) + L = sqrt(cov(q.dist)) + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q.dist) + L_repl = sqrt(cov(q.dist)) + @test μ == μ_repl + @test L == L_repl + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 6fda0be8..757a931d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,3 +40,4 @@ include("interface/optimize.jl") include("interface/advi.jl") include("inference/advi_distributionsad.jl") +include("inference/advi_distributionsad_bijectors.jl") From 316b629eb965a591019b7149bbcf7fc72e613b9b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 21 Nov 2023 01:50:41 -0500 Subject: [PATCH 199/206] refactor rename advi to repgradelbo and not use bijectors directly --- Project.toml | 2 - src/AdvancedVI.jl | 9 +- src/objectives/elbo/advi.jl | 166 ------------------ src/objectives/elbo/repgradelbo.jl | 126 +++++++++++++ src/utils.jl | 3 + test/Project.toml | 2 - .../advi_distributionsad_bijectors.jl | 81 --------- ...nsad.jl => repgradelbo_distributionsad.jl} | 20 +-- test/interface/advi.jl | 55 ------ test/interface/optimize.jl | 22 ++- test/interface/repgradelbo.jl | 28 +++ test/models/normallognormal.jl | 65 ------- test/runtests.jl | 8 +- 13 files changed, 182 insertions(+), 405 deletions(-) delete mode 100644 src/objectives/elbo/advi.jl create mode 100644 src/objectives/elbo/repgradelbo.jl delete mode 100644 test/inference/advi_distributionsad_bijectors.jl rename test/inference/{advi_distributionsad.jl => repgradelbo_distributionsad.jl} (78%) delete mode 100644 test/interface/advi.jl create mode 100644 test/interface/repgradelbo.jl delete mode 100644 test/models/normallognormal.jl diff --git a/Project.toml b/Project.toml index 70041561..7799d505 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.3.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -36,7 +35,6 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1, 0.2" Accessors = "0.1" -Bijectors = "0.12, 0.13" ChainRulesCore = "1.16" DiffResults = "1" Distributions = "0.25.87" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index b1decc4a..bb5b6e85 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -11,7 +11,6 @@ using Functors using Optimisers using DocStringExtensions - using ProgressMeter using LinearAlgebra @@ -21,7 +20,6 @@ using ADTypes, DiffResults using ChainRulesCore using FillArrays -using Bijectors using StatsBase @@ -115,18 +113,17 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ """ function estimate_gradient! end -# ADVI-specific interfaces +# ELBO-specific interfaces abstract type AbstractEntropyEstimator end export - ADVI, + RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy -# entropy.jl must preceed advi.jl include("objectives/elbo/entropy.jl") -include("objectives/elbo/advi.jl") +include("objectives/elbo/repgradelbo.jl") # Optimization Routine diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl deleted file mode 100644 index 98f8ae99..00000000 --- a/src/objectives/elbo/advi.jl +++ /dev/null @@ -1,166 +0,0 @@ - -""" - ADVI(n_samples; kwargs...) - -Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. -This computes the evidence lower-bound (ELBO) through the ADVI formulation: -```math -\\begin{aligned} -\\mathrm{ADVI}\\left(\\lambda\\right) -&\\triangleq -\\mathbb{E}_{\\eta \\sim q_{\\lambda}}\\left[ - \\log \\pi\\left( \\phi^{-1}\\left( \\eta \\right) \\right) - + - \\log \\lvert J_{\\phi^{-1}}\\left(\\eta\\right) \\rvert -\\right] -+ \\mathbb{H}\\left(q_{\\lambda}\\right), -\\end{aligned} -``` -where ``\\phi^{-1}`` is an "inverse bijector." - -# Arguments -- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO. - -# Keyword Arguments -- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) - -# Requirements -- ``q_{\\lambda}`` implements `rand`. -- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend. - -Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. - -# References -* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. -* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. -""" -struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective - entropy ::EntropyEst - n_samples::Int -end - -ADVI( - n_samples::Int; - entropy ::AbstractEntropyEstimator = ClosedFormEntropy() -) = ADVI(entropy, n_samples) - -Base.show(io::IO, advi::ADVI) = - print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") - -maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop - -maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q - -function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, mc_samples, q, q_stop) - q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) - estimate_entropy(entropy_estimator, mc_samples, q_maybe_stop) -end - -function estimate_energy_with_samples(::ADVI, mc_samples::AbstractMatrix, prob) - mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples)) -end - -function estimate_energy_with_samples_bijector(::ADVI, mc_samples::AbstractMatrix, invbij, prob) - mean(eachcol(mc_samples)) do mc_sample - mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(invbij, mc_sample) - LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ - end -end - -function estimate_advi_maybe_stl_with_samples( - advi ::ADVI, - q ::ContinuousDistribution, - q_stop ::ContinuousDistribution, - mc_samples::AbstractMatrix, - prob -) - energy = estimate_energy_with_samples(advi, mc_samples, prob) - entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop) - energy + entropy -end - -function estimate_advi_maybe_stl_with_samples( - advi ::ADVI, - q_trans ::Bijectors.TransformedDistribution, - q_trans_stop::Bijectors.TransformedDistribution, - mc_samples ::AbstractMatrix, - prob -) - q = q_trans.dist - invbij = q_trans.transform - q_stop = q_trans_stop.dist - energy = estimate_energy_with_samples_bijector(advi, mc_samples, invbij, prob) - entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop) - energy + entropy -end - -rand_unconstrained( - rng ::Random.AbstractRNG, - q ::ContinuousDistribution, - n_samples::Int -) = rand(rng, q, n_samples) - -rand_unconstrained( - rng ::Random.AbstractRNG, - q ::Bijectors.TransformedDistribution, - n_samples::Int -) = rand(rng, q.dist, n_samples) - -function estimate_advi_maybe_stl(rng::Random.AbstractRNG, advi::ADVI, q, q_stop, prob) - mc_samples = rand_unconstrained(rng, q, advi.n_samples) - estimate_advi_maybe_stl_with_samples(advi, q, q_stop, mc_samples, prob) -end - -""" - estimate_objective([rng,] advi, q, prob; n_samples) - -Estimate the ELBO using the ADVI formulation. - -# Arguments -- `advi::ADVI`: ADVI objective. -- `q`: Variational approximation -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - -# Keyword Arguments -- `n_samples::Int = advi.n_samples`: Number of samples to be used to estimate the objective. - -# Returns -- `obj_est`: Estimate of the objective value. -""" -function estimate_objective( - rng ::Random.AbstractRNG, - advi ::ADVI, - q, - prob; - n_samples::Int = advi.n_samples -) - mc_samples = rand_unconstrained(rng, q, n_samples) - estimate_advi_maybe_stl_with_samples(advi, q, q, mc_samples, prob) -end - -estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) = - estimate_objective(Random.default_rng(), advi, q, prob; n_samples) - -function estimate_gradient!( - rng ::Random.AbstractRNG, - advi ::ADVI, - adbackend ::ADTypes.AbstractADType, - out ::DiffResults.MutableDiffResult, - prob, - λ, - restructure, - est_state, -) - q_stop = restructure(λ) - function f(λ′) - q = restructure(λ′) - elbo = estimate_advi_maybe_stl(rng, advi, q, q_stop, prob) - -elbo - end - value_and_gradient!(adbackend, f, λ, out) - - nelbo = DiffResults.value(out) - stat = (elbo=-nelbo,) - - out, nothing, stat -end diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl new file mode 100644 index 00000000..09ba1a79 --- /dev/null +++ b/src/objectives/elbo/repgradelbo.jl @@ -0,0 +1,126 @@ + +""" + RepGradELBO(n_samples; kwargs...) + +Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014]. +This computes the evidence lower-bound (ELBO) through the formulation: +```math +\\begin{aligned} +\\mathrm{ELBO}\\left(\\lambda\\right) +&\\triangleq +\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[ + \\log \\pi\\left(z\\right) +\\right] ++ \\mathbb{H}\\left(q_{\\lambda}\\right), +\\end{aligned} +``` + +# Arguments +- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO. + +# Keyword Arguments +- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) + +# Requirements +- ``q_{\\lambda}`` implements `rand`. +- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend. + +Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. + +# References +[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In ICML. +[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In ICML. +[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In ICLR. +""" +struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + entropy ::EntropyEst + n_samples::Int +end + +RepGradELBO( + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy() +) = RepGradELBO(entropy, n_samples) + +Base.show(io::IO, obj::RepGradELBO) = + print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))") + +maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop + +maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q + +function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) + q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) + estimate_entropy(entropy_estimator, samples, q_maybe_stop) +end + +function estimate_energy_with_samples(::RepGradELBO, samples, prob) + mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) +end + +function estimate_repgradelbo_maybe_stl_with_samples( + obj::RepGradELBO, q, q_stop, samples::AbstractMatrix, prob +) + energy = estimate_energy_with_samples(obj, samples, prob) + entropy = estimate_entropy_maybe_stl(obj.entropy, samples, q, q_stop) + energy + entropy +end + +function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELBO, q, q_stop, prob) + samples = rand(rng, q, obj.n_samples) + estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob) +end + +""" + estimate_objective([rng,] obj, q, prob; n_samples) + +Estimate the ELBO using the reparameterization gradient formulation. + +# Arguments +- `obj::RepGradELBO`: The ELBO objective. +- `q`: Variational approximation +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. + +# Keyword Arguments +- `n_samples::Int = obj.n_samples`: Number of samples to be used to estimate the objective. + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_objective( + rng::Random.AbstractRNG, + obj::RepGradELBO, + q, + prob; + n_samples::Int = obj.n_samples +) + samples = rand(rng, q, n_samples) + estimate_repgradelbo_maybe_stl_with_samples(obj, q, q, samples, prob) +end + +estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = + estimate_objective(Random.default_rng(), obj, q, prob; n_samples) + +function estimate_gradient!( + rng ::Random.AbstractRNG, + obj ::RepGradELBO, + adbackend::ADTypes.AbstractADType, + out ::DiffResults.MutableDiffResult, + prob, + λ, + restructure, + est_state, +) + q_stop = restructure(λ) + function f(λ′) + q = restructure(λ′) + elbo = estimate_repgradelbo_maybe_stl(rng, obj, q, q_stop, prob) + -elbo + end + value_and_gradient!(adbackend, f, λ, out) + + nelbo = DiffResults.value(out) + stat = (elbo=-nelbo,) + + out, nothing, stat +end diff --git a/src/utils.jl b/src/utils.jl index 8dd7c37b..76637fa3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -21,3 +21,6 @@ function maybe_init_objective( haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure) end +eachsample(samples::AbstractMatrix) = eachcol(samples) + +eachsample(samples::AbstractVector) = samples diff --git a/test/Project.toml b/test/Project.toml index 490782cb..7d0bf2d2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,5 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -23,7 +22,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1" -Bijectors = "0.13.6" Distributions = "0.25.100" DistributionsAD = "0.6.45" Enzyme = "0.11.7" diff --git a/test/inference/advi_distributionsad_bijectors.jl b/test/inference/advi_distributionsad_bijectors.jl deleted file mode 100644 index 29602fe7..00000000 --- a/test/inference/advi_distributionsad_bijectors.jl +++ /dev/null @@ -1,81 +0,0 @@ - -const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false - -using Test - -@testset "inference_advi_distributionsad_bijectors" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype ∈ [Float64, Float32], - (modelname, modelconstr) ∈ Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - ), - (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => ADVI(10), - :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), - ), - (adbackname, adbackend) ∈ Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), - #:Enzyme => AutoEnzyme(), - ) - - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - modelstats = modelconstr(rng, realtype) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) - - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) - μ₀ = Zeros(realtype, n_dims) - L₀ = Diagonal(Ones(realtype, n_dims)) - - q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) - q₀_z = Bijectors.transformed(q₀_η, b⁻¹) - - @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats, _ = optimize( - rng, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - - μ = mean(q.dist) - L = sqrt(cov(q.dist)) - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ Δλ₀/T^(1/4) - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = StableRNG(seed) - q, stats, _ = optimize( - rng, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - μ = mean(q.dist) - L = sqrt(cov(q.dist)) - - rng_repl = StableRNG(seed) - q, stats, _ = optimize( - rng_repl, model, objective, q₀_z, T; - optimizer = Optimisers.Adam(realtype(η)), - show_progress = PROGRESS, - adbackend = adbackend, - ) - μ_repl = mean(q.dist) - L_repl = sqrt(cov(q.dist)) - @test μ == μ_repl - @test L == L_repl - end - end -end diff --git a/test/inference/advi_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl similarity index 78% rename from test/inference/advi_distributionsad.jl rename to test/inference/repgradelbo_distributionsad.jl index e82a9ec0..29cb2d83 100644 --- a/test/inference/advi_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -3,15 +3,15 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using Test -@testset "inference_advi_distributionsad" begin +@testset "inference RepGradELBO DistributionsAD" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype ∈ [Float64, Float32], (modelname, modelconstr) ∈ Dict( :Normal=> normal_meanfield, ), (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => ADVI(10), - :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()), + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), @@ -28,14 +28,14 @@ using Test T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) - μ₀ = Zeros(realtype, n_dims) - L₀ = Diagonal(Ones(realtype, n_dims)) - q₀_z = TuringDiagMvNormal(μ₀, diag(L₀)) + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) + q0 = TuringDiagMvNormal(μ0, diag(L0)) @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( - rng, model, objective, q₀_z, T; + rng, model, objective, q0, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, adbackend = adbackend, @@ -53,7 +53,7 @@ using Test @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q₀_z, T; + rng, model, objective, q0, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, adbackend = adbackend, @@ -63,7 +63,7 @@ using Test rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q₀_z, T; + rng_repl, model, objective, q0, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, adbackend = adbackend, diff --git a/test/interface/advi.jl b/test/interface/advi.jl deleted file mode 100644 index 1df396e4..00000000 --- a/test/interface/advi.jl +++ /dev/null @@ -1,55 +0,0 @@ - -using Test - -@testset "advi" begin - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - @testset "with bijector" begin - modelstats = normallognormal_meanfield(rng, Float64) - - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - b⁻¹ = Bijectors.bijector(model) |> inverse - q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) - q₀_z = Bijectors.transformed(q₀_η, b⁻¹) - obj = ADVI(10) - - rng = StableRNG(seed) - elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) - - @testset "determinism" begin - rng = StableRNG(seed) - elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) - @test elbo == elbo_ref - end - - @testset "default_rng" begin - elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4) - @test elbo ≈ elbo_ref rtol=0.1 - end - end - - @testset "without bijector" begin - modelstats = normal_meanfield(rng, Float64) - - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - q₀_z = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) - - obj = ADVI(10) - rng = StableRNG(seed) - elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) - - @testset "determinism" begin - rng = StableRNG(seed) - elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4) - @test elbo == elbo_ref - end - - @testset "default_rng" begin - elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4) - @test elbo ≈ elbo_ref rtol=0.1 - end - end -end diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index 3459b4c3..6e69616b 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -1,27 +1,25 @@ using Test -@testset "optimize" begin +@testset "interface optimize" begin seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) T = 1000 - modelstats = normallognormal_meanfield(rng, Float64) + modelstats = normal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats # Global Test Configurations - b⁻¹ = Bijectors.bijector(model) |> inverse - q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) - q₀_z = Bijectors.transformed(q₀_η, b⁻¹) - obj = ADVI(10) + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + obj = RepGradELBO(10) adbackend = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) rng = StableRNG(seed) q_ref, stats_ref, _ = optimize( - rng, model, obj, q₀_z, T; + rng, model, obj, q0, T; optimizer, show_progress = false, adbackend, @@ -30,13 +28,13 @@ using Test @testset "default_rng" begin optimize( - model, obj, q₀_z, T; + model, obj, q0, T; optimizer, show_progress = false, adbackend, ) - λ₀, re = Optimisers.destructure(q₀_z) + λ₀, re = Optimisers.destructure(q0) optimize( model, obj, re, λ₀, T; optimizer, @@ -46,7 +44,7 @@ using Test end @testset "restructure" begin - λ₀, re = Optimisers.destructure(q₀_z) + λ₀, re = Optimisers.destructure(q0) rng = StableRNG(seed) λ, stats, _ = optimize( @@ -67,7 +65,7 @@ using Test rng = StableRNG(seed) _, stats, _ = optimize( - rng, model, obj, q₀_z, T; + rng, model, obj, q0, T; show_progress = false, adbackend, callback @@ -82,7 +80,7 @@ using Test T_last = T - T_first q_first, _, state = optimize( - rng, model, obj, q₀_z, T_first; + rng, model, obj, q0, T_first; optimizer, show_progress = false, adbackend diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl new file mode 100644 index 00000000..61ff0111 --- /dev/null +++ b/test/interface/repgradelbo.jl @@ -0,0 +1,28 @@ + +using Test + +@testset "interface RepGradELBO" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + + obj = RepGradELBO(10) + rng = StableRNG(seed) + elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4) + + @testset "determinism" begin + rng = StableRNG(seed) + elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4) + @test elbo == elbo_ref + end + + @testset "default_rng" begin + elbo = estimate_objective(obj, q0, model; n_samples=10^4) + @test elbo ≈ elbo_ref rtol=0.1 + end +end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl deleted file mode 100644 index c2cb2b0e..00000000 --- a/test/models/normallognormal.jl +++ /dev/null @@ -1,65 +0,0 @@ - -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end - -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end - -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end - -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end - -function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - L_y = tril(I + ones(realtype, n_dims, n_dims))/2 - Σ_y = L_y*L_y' |> Hermitian - - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) - - Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) - Σ[1,1] = σ_x^2 - Σ[2:end,2:end] = Σ_y - Σ = Σ |> Hermitian - - μ = vcat(μ_x, μ_y) - L = cholesky(Σ).L - - TestModel(model, μ, L, n_dims+1, false) -end - -function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - - model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) - - μ = vcat(μ_x, μ_y) - L = vcat(σ_x, σ_y) |> Diagonal - - TestModel(model, μ, L, n_dims+1, true) -end diff --git a/test/runtests.jl b/test/runtests.jl index 757a931d..a855541c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,6 @@ using Functors using DistributionsAD @functor TuringDiagMvNormal -using Bijectors using LogDensityProblems using Optimisers using ADTypes @@ -30,14 +29,11 @@ struct TestModel{M,L,S} n_dims::Int is_meanfield::Bool end - -include("models/normallognormal.jl") include("models/normal.jl") # Tests include("interface/ad.jl") include("interface/optimize.jl") -include("interface/advi.jl") +include("interface/repgradelbo.jl") -include("inference/advi_distributionsad.jl") -include("inference/advi_distributionsad_bijectors.jl") +include("inference/repgradelbo_distributionsad.jl") From 13b208868dc2b3d3d6d5fe9b38228fa2879684cd Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 23 Nov 2023 00:44:57 -0500 Subject: [PATCH 200/206] fix documentation for estimate_objective --- src/AdvancedVI.jl | 20 ++++++++++++++++++-- src/objectives/elbo/repgradelbo.jl | 20 -------------------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index bb5b6e85..d17a088c 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -70,7 +70,7 @@ init( ) = nothing """ - estimate_objective([rng,] obj, q, prob, kwargs...) + estimate_objective([rng,] obj, q, prob; kwargs...) Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. @@ -81,7 +81,8 @@ Estimate the variational objective `obj` targeting `prob` with respect to the va - `q`: Variational approximation. # Keyword Arguments -For the keywword arguments, refer to the respective documentation for each variational objective. +Depending on the objective, additional keyword arguments may apply. +Please refer to the respective documentation of each variational objective for more info. # Returns - `obj_est`: Estimate of the objective value. @@ -116,6 +117,21 @@ function estimate_gradient! end # ELBO-specific interfaces abstract type AbstractEntropyEstimator end +""" + estimate_entropy(entropy_estimator, mc_samples, q) + +Estimate the entropy of `q`. + +# Arguments +- `entropy_estimator`: Entropy estimation strategy. +- `q`: Variational approximation. +- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_entropy end + export RepGradELBO, ClosedFormEntropy, diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 09ba1a79..48a5461f 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -45,10 +45,6 @@ RepGradELBO( Base.show(io::IO, obj::RepGradELBO) = print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))") -maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop - -maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q - function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) estimate_entropy(entropy_estimator, samples, q_maybe_stop) @@ -71,22 +67,6 @@ function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELB estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob) end -""" - estimate_objective([rng,] obj, q, prob; n_samples) - -Estimate the ELBO using the reparameterization gradient formulation. - -# Arguments -- `obj::RepGradELBO`: The ELBO objective. -- `q`: Variational approximation -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - -# Keyword Arguments -- `n_samples::Int = obj.n_samples`: Number of samples to be used to estimate the objective. - -# Returns -- `obj_est`: Estimate of the objective value. -""" function estimate_objective( rng::Random.AbstractRNG, obj::RepGradELBO, From b0e1be14bed230f307d704641bdc911728367613 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 23 Nov 2023 02:05:48 -0500 Subject: [PATCH 201/206] refactor add indirection in repgradelbo for interacting with `q` --- src/objectives/elbo/entropy.jl | 20 +++----------- src/objectives/elbo/repgradelbo.jl | 44 ++++++++++++++++++++---------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 6fa3095e..231b1652 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -1,20 +1,4 @@ -""" - estimate_entropy(entropy_estimator, mc_samples, q) - -Estimate the entropy of `q`. - -# Arguments -- `entropy_estimator`: Entropy estimation strategy. -- `q`: Variational approximation. -- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) - -# Returns -- `obj_est`: Estimate of the objective value. -""" - -function estimate_entropy end - """ ClosedFormEntropy() @@ -29,6 +13,8 @@ Use closed-form expression of entropy. """ struct ClosedFormEntropy <: AbstractEntropyEstimator end +maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q + function estimate_entropy(::ClosedFormEntropy, ::Any, q) entropy(q) end @@ -49,6 +35,8 @@ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end struct MonteCarloEntropy <: AbstractEntropyEstimator end +maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop + function estimate_entropy( ::Union{MonteCarloEntropy, StickingTheLandingEntropy}, mc_samples::AbstractMatrix, diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 48a5461f..28bd681f 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -50,21 +50,32 @@ function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, estimate_entropy(entropy_estimator, samples, q_maybe_stop) end -function estimate_energy_with_samples(::RepGradELBO, samples, prob) +function estimate_energy_with_samples(prob, samples) mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end -function estimate_repgradelbo_maybe_stl_with_samples( - obj::RepGradELBO, q, q_stop, samples::AbstractMatrix, prob -) - energy = estimate_energy_with_samples(obj, samples, prob) - entropy = estimate_entropy_maybe_stl(obj.entropy, samples, q, q_stop) - energy + entropy -end +""" + reparam_with_entropy(rng, n_samples, q, q_stop, ent_est) + +Draw `n_samples` from `q` and compute its entropy. -function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELBO, q, q_stop, prob) - samples = rand(rng, q, obj.n_samples) - estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob) +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `n_samples::Int`: Number of Monte Carlo samples +- `q`: Variational approximation. +- `q_stop`: `q` but with its gradient stopped. +- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.) + +# Returns +- `samples`: Monte Carlo samples generated through reparameterization. Their support matches that of the target distribution. +- `entropy`: An estimate (or exact value) of the differential entropy of `q`. +""" +function reparam_with_entropy( + rng::Random.AbstractRNG, n_samples::Int, q, q_stop, ent_est +) + samples = rand(rng, q, n_samples) + entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) + samples, entropy end function estimate_objective( @@ -74,8 +85,9 @@ function estimate_objective( prob; n_samples::Int = obj.n_samples ) - samples = rand(rng, q, n_samples) - estimate_repgradelbo_maybe_stl_with_samples(obj, q, q, samples, prob) + samples, entropy = reparam_with_entropy(rng, n_samples, q, q, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + energy + entropy end estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = @@ -89,12 +101,14 @@ function estimate_gradient!( prob, λ, restructure, - est_state, + state, ) q_stop = restructure(λ) function f(λ′) q = restructure(λ′) - elbo = estimate_repgradelbo_maybe_stl(rng, obj, q, q_stop, prob) + samples, entropy = reparam_with_entropy(rng, obj.n_samples, q, q_stop, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + elbo = energy + entropy -elbo end value_and_gradient!(adbackend, f, λ, out) From 7361ed4d5abafa46a3b16b74fec0be612d859d7e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 23 Nov 2023 02:37:42 -0500 Subject: [PATCH 202/206] add TransformedDistribution support as extension --- Project.toml | 4 + ext/AdvancedVIBijectorsExt.jl | 37 +++++++++ src/AdvancedVI.jl | 3 + test/Project.toml | 2 + .../repgradelbo_distributionsad_bijectors.jl | 81 +++++++++++++++++++ test/models/normallognormal.jl | 65 +++++++++++++++ test/runtests.jl | 3 + 7 files changed, 195 insertions(+) create mode 100644 ext/AdvancedVIBijectorsExt.jl create mode 100644 test/inference/repgradelbo_distributionsad_bijectors.jl create mode 100644 test/models/normallognormal.jl diff --git a/Project.toml b/Project.toml index 7799d505..f4ea1bcc 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -31,10 +32,12 @@ AdvancedVIEnzymeExt = "Enzyme" AdvancedVIForwardDiffExt = "ForwardDiff" AdvancedVIReverseDiffExt = "ReverseDiff" AdvancedVIZygoteExt = "Zygote" +AdvancedVIBijectorsExt = "Bijectors" [compat] ADTypes = "0.1, 0.2" Accessors = "0.1" +Bijectors = "0.13" ChainRulesCore = "1.16" DiffResults = "1" Distributions = "0.25.87" @@ -56,6 +59,7 @@ Zygote = "0.6.63" julia = "1.6" [extras] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl new file mode 100644 index 00000000..5d9dc774 --- /dev/null +++ b/ext/AdvancedVIBijectorsExt.jl @@ -0,0 +1,37 @@ + +module AdvancedVIBijectorsExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using Bijectors + using Random +else + using ..AdvancedVI + using ..Bijectors + using ..Random +end + +function AdvancedVI.reparam_with_entropy( + rng ::Random.AbstractRNG, + n_samples::Int, + q ::Bijectors.TransformedDistribution, + q_stop ::Bijectors.TransformedDistribution, + ent_est +) + transform = q.transform + q_base = q.dist + q_base_stop = q_stop.dist + ∑logabsdetjac = 0.0 + base_samples = rand(rng, q_base, n_samples) + samples = mapreduce(hcat, eachcol(base_samples)) do base_sample + sample, logabsdetjac = with_logabsdet_jacobian(transform, base_sample) + ∑logabsdetjac += logabsdetjac + sample + end + entropy_base = AdvancedVI.estimate_entropy_maybe_stl( + ent_est, base_samples, q_base, q_base_stop + ) + entropy = entropy_base + ∑logabsdetjac/n_samples + samples, entropy +end +end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d17a088c..89f86696 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -158,6 +158,9 @@ end @static if !isdefined(Base, :get_extension) function __init__() + @require Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" begin + include("../ext/AdvancedVIBijectorsExt.jl") + end @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("../ext/AdvancedVIEnzymeExt.jl") end diff --git a/test/Project.toml b/test/Project.toml index 7d0bf2d2..a751b89d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -22,6 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1" +Bijectors = "0.13" Distributions = "0.25.100" DistributionsAD = "0.6.45" Enzyme = "0.11.7" diff --git a/test/inference/repgradelbo_distributionsad_bijectors.jl b/test/inference/repgradelbo_distributionsad_bijectors.jl new file mode 100644 index 00000000..9f1e3cc4 --- /dev/null +++ b/test/inference/repgradelbo_distributionsad_bijectors.jl @@ -0,0 +1,81 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference RepGradELBO DistributionsAD Bijectors" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + ), + (objname, objective) ∈ Dict( + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ₀ = Zeros(realtype, n_dims) + L₀ = Diagonal(Ones(realtype, n_dims)) + + q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) + q₀_z = Bijectors.transformed(q₀_η, b⁻¹) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = mean(q.dist) + L = sqrt(cov(q.dist)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = mean(q.dist) + L = sqrt(cov(q.dist)) + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q.dist) + L_repl = sqrt(cov(q.dist)) + @test μ == μ_repl + @test L == L_repl + end + end +end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl new file mode 100644 index 00000000..6615084b --- /dev/null +++ b/test/models/normallognormal.jl @@ -0,0 +1,65 @@ + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 + Σ_y = L_y*L_y' |> Hermitian + + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) + + Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) + Σ[1,1] = σ_x^2 + Σ[2:end,2:end] = Σ_y + Σ = Σ |> Hermitian + + μ = vcat(μ_x, μ_y) + L = cholesky(Σ).L + + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + + μ = vcat(μ_x, μ_y) + L = vcat(σ_x, σ_y) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end diff --git a/test/runtests.jl b/test/runtests.jl index a855541c..b14b8b2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Test using Test: @testset, @test +using Bijectors using Random, StableRNGs using Statistics using Distributions @@ -30,6 +31,7 @@ struct TestModel{M,L,S} is_meanfield::Bool end include("models/normal.jl") +include("models/normallognormal.jl") # Tests include("interface/ad.jl") @@ -37,3 +39,4 @@ include("interface/optimize.jl") include("interface/repgradelbo.jl") include("inference/repgradelbo_distributionsad.jl") +include("inference/repgradelbo_distributionsad_bijectors.jl") From d2e76143f18cfbe0631816e8af03b6b531256067 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 8 Dec 2023 02:24:30 -0500 Subject: [PATCH 203/206] Update src/objectives/elbo/repgradelbo.jl Co-authored-by: Tor Erlend Fjelde --- src/objectives/elbo/repgradelbo.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 28bd681f..7e093135 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -43,7 +43,11 @@ RepGradELBO( ) = RepGradELBO(entropy, n_samples) Base.show(io::IO, obj::RepGradELBO) = - print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))") + print(io, "RepGradELBO(entropy=") + print(io, obj.entropy) + print(io, ", n_samples=") + print(io, obj.n_samples) + print(io, ")") function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) From 77686b5c776de4e42637932b0310c36b6e4a8d86 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 8 Dec 2023 02:31:21 -0500 Subject: [PATCH 204/206] fix docstring for entropy estimator --- src/objectives/elbo/entropy.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 231b1652..6c5b4739 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -5,7 +5,7 @@ Use closed-form expression of entropy. # Requirements -- `q` implements `entropy`. +- The variational approximation implements `entropy`. # References * Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. @@ -25,7 +25,7 @@ end The "sticking the landing" entropy estimator. # Requirements -- `q` implements `logpdf`. +- The variational approximation `q` implements `logpdf`. - `logpdf(q, η)` must be differentiable by the selected AD framework. # References From 8461b43c821980bd433631fd3ce9e369db56b01c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 8 Dec 2023 03:12:52 -0500 Subject: [PATCH 205/206] fix `reparam_with_entropy` specialization for bijectors --- ext/AdvancedVIBijectorsExt.jl | 32 +++++++++++++++++++----------- src/objectives/elbo/repgradelbo.jl | 19 +++++++++++------- src/utils.jl | 10 ++++++++++ 3 files changed, 42 insertions(+), 19 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 5d9dc774..1b200ac5 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -13,25 +13,33 @@ end function AdvancedVI.reparam_with_entropy( rng ::Random.AbstractRNG, - n_samples::Int, q ::Bijectors.TransformedDistribution, q_stop ::Bijectors.TransformedDistribution, - ent_est + n_samples::Int, + ent_est ::AdvancedVI.AbstractEntropyEstimator ) - transform = q.transform - q_base = q.dist - q_base_stop = q_stop.dist - ∑logabsdetjac = 0.0 - base_samples = rand(rng, q_base, n_samples) - samples = mapreduce(hcat, eachcol(base_samples)) do base_sample - sample, logabsdetjac = with_logabsdet_jacobian(transform, base_sample) - ∑logabsdetjac += logabsdetjac - sample + transform = q.transform + q_base = q.dist + q_base_stop = q_stop.dist + base_samples = rand(rng, q_base, n_samples) + it = AdvancedVI.eachsample(base_samples) + sample_init = first(it) + + samples_and_logjac = mapreduce( + AdvancedVI.catsamples_and_acc, + Iterators.drop(it, 1); + init=with_logabsdet_jacobian(transform, sample_init) + ) do sample + with_logabsdet_jacobian(transform, sample) end + samples = first(samples_and_logjac) + logjac = last(samples_and_logjac) + entropy_base = AdvancedVI.estimate_entropy_maybe_stl( ent_est, base_samples, q_base, q_base_stop ) - entropy = entropy_base + ∑logabsdetjac/n_samples + + entropy = entropy_base + logjac/n_samples samples, entropy end end diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 28bd681f..04f35320 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -42,8 +42,13 @@ RepGradELBO( entropy ::AbstractEntropyEstimator = ClosedFormEntropy() ) = RepGradELBO(entropy, n_samples) -Base.show(io::IO, obj::RepGradELBO) = - print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))") +function Base.show(io::IO, obj::RepGradELBO) + print(io, "RepGradELBO(entropy=") + print(io, obj.entropy) + print(io, ", n_samples=") + print(io, obj.n_samples) + print(io, ")") +end function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) @@ -55,15 +60,15 @@ function estimate_energy_with_samples(prob, samples) end """ - reparam_with_entropy(rng, n_samples, q, q_stop, ent_est) + reparam_with_entropy(rng, q, q_stop, n_samples, ent_est) Draw `n_samples` from `q` and compute its entropy. # Arguments - `rng::Random.AbstractRNG`: Random number generator. -- `n_samples::Int`: Number of Monte Carlo samples - `q`: Variational approximation. - `q_stop`: `q` but with its gradient stopped. +- `n_samples::Int`: Number of Monte Carlo samples - `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.) # Returns @@ -71,7 +76,7 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. """ function reparam_with_entropy( - rng::Random.AbstractRNG, n_samples::Int, q, q_stop, ent_est + rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) @@ -85,7 +90,7 @@ function estimate_objective( prob; n_samples::Int = obj.n_samples ) - samples, entropy = reparam_with_entropy(rng, n_samples, q, q, obj.entropy) + samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) energy + entropy end @@ -106,7 +111,7 @@ function estimate_gradient!( q_stop = restructure(λ) function f(λ′) q = restructure(λ′) - samples, entropy = reparam_with_entropy(rng, obj.n_samples, q, q_stop, obj.entropy) + samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) elbo = energy + entropy -elbo diff --git a/src/utils.jl b/src/utils.jl index 76637fa3..8e67ff1a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,3 +24,13 @@ end eachsample(samples::AbstractMatrix) = eachcol(samples) eachsample(samples::AbstractVector) = samples + +function catsamples_and_acc( + state_curr::Tuple{<:AbstractArray, <:Real}, + state_new ::Tuple{<:AbstractVector, <:Real} +) + x = hcat(first(state_curr), first(state_new)) + ∑y = last(state_curr) + last(state_new) + return (x, ∑y) +end + From bd925cce08473c9f5698beac05d810c146ecc56c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 8 Dec 2023 03:19:58 -0500 Subject: [PATCH 206/206] enable Zygote for non-bijector tests --- test/inference/repgradelbo_distributionsad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 29cb2d83..b6db22a6 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -16,7 +16,7 @@ using Test (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), + :Zygote => AutoZygote(), #:Enzyme => AutoEnzyme(), )