From 4eab1ac5cfd4a79e0d6d342f1f441c7aa226b0eb Mon Sep 17 00:00:00 2001 From: Arnau Quera-Bofarull Date: Mon, 30 Sep 2024 06:00:53 +0100 Subject: [PATCH] Add ScoreELBO objective (#72) * add score estimator with baseline variance reduction. * run formatter Co-authored-by: Kyurae Kim --------- Co-authored-by: Kyurae Kim --- README.md | 43 +++--- bench/README.md | 2 - bench/benchmarks.jl | 34 ++--- bench/normallognormal.jl | 36 ++--- bench/utils.jl | 14 +- docs/src/elbo/overview.md | 12 +- docs/src/general.md | 20 ++- docs/src/index.md | 5 +- src/AdvancedVI.jl | 4 +- src/objectives/elbo/entropy.jl | 7 + src/objectives/elbo/repgradelbo.jl | 7 - src/objectives/elbo/scoregradelbo.jl | 139 ++++++++++++++++++ src/optimize.jl | 1 - .../scoregradelbo_distributionsad.jl | 101 +++++++++++++ test/inference/scoregradelbo_locationscale.jl | 105 +++++++++++++ .../scoregradelbo_locationscale_bijectors.jl | 111 ++++++++++++++ test/interface/scoregradelbo.jl | 57 +++++++ test/runtests.jl | 4 + 18 files changed, 618 insertions(+), 84 deletions(-) create mode 100644 src/objectives/elbo/scoregradelbo.jl create mode 100644 test/inference/scoregradelbo_distributionsad.jl create mode 100644 test/inference/scoregradelbo_locationscale.jl create mode 100644 test/inference/scoregradelbo_locationscale_bijectors.jl create mode 100644 test/interface/scoregradelbo.jl diff --git a/README.md b/README.md index f3bb745f..9bd000c7 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ [![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl) # AdvancedVI.jl + [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational inference (VI) algorithms, which is a family of algorithms aiming for scalable approximate Bayesian inference by leveraging optimization. `AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. -The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. +The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. For example, integrating `Turing` with `AdvancedVI.ADVI` only involves converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`. 
## Examples @@ -21,7 +22,8 @@ y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right), \end{aligned} $$ -a `LogDensityProblem` can be implemented as +a `LogDensityProblem` can be implemented as + ```julia using LogDensityProblems using SimpleUnPack @@ -35,46 +37,50 @@ end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) (; μ_x, σ_x, μ_y, Σ_y) = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) + return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 + return length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{0}() end ``` -Since the support of `x` is constrained to be positive and VI is best done in the unconstrained Euclidean space, we need to use a *bijector* to transform `x` into unconstrained Euclidean space. We will use the [`Bijectors.jl`](https://github.com/TuringLang/Bijectors.jl) package for this purpose. +Since the support of `x` is constrained to be positive and VI is best done in the unconstrained Euclidean space, we need to use a *bijector* to transform `x` into unconstrained Euclidean space. We will use the [`Bijectors.jl`](https://github.com/TuringLang/Bijectors.jl) package for this purpose. This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017]. + ```julia using Bijectors function Bijectors.bijector(model::NormalLogNormal) (; μ_x, σ_x, μ_y, Σ_y) = model - Bijectors.Stacked( + return Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) + [1:1, 2:(1 + length(μ_y))], + ) end ``` A simpler approach is to use `Turing`, where a `Turing.Model` can be automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated. Let us instantiate a random normal-log-normal model. 
+ ```julia using LinearAlgebra n_dims = 10 -μ_x = randn() -σ_x = exp.(randn()) -μ_y = randn(n_dims) -σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)) ``` We can perform VI with stochastic gradient descent (SGD) using reparameterization gradient estimates of the ELBO[^TL2014][^RMW2014][^KW2014] as follows: + ```julia using Optimisers using ADTypes, ForwardDiff @@ -82,7 +88,7 @@ using AdvancedVI # ELBO objective with the reparameterization gradient n_montecarlo = 10 -elbo = AdvancedVI.RepGradELBO(n_montecarlo) +elbo = AdvancedVI.RepGradELBO(n_montecarlo) # Mean-field Gaussian variational family d = LogDensityProblems.dimension(model) @@ -91,11 +97,10 @@ L = Diagonal(ones(d)) q = AdvancedVI.MeanFieldGaussian(μ, L) # Match support by applying the `model`'s inverse bijector -b = Bijectors.bijector(model) -binv = inverse(b) +b = Bijectors.bijector(model) +binv = inverse(b) q_transformed = Bijectors.TransformedDistribution(q, binv) - # Run inference max_iter = 10^3 q_avg, _, stats, _ = AdvancedVI.optimize( @@ -103,8 +108,8 @@ q_avg, _, stats, _ = AdvancedVI.optimize( elbo, q_transformed, max_iter; - adtype = ADTypes.AutoForwardDiff(), - optimizer = Optimisers.Adam(1e-3) + adtype=ADTypes.AutoForwardDiff(), + optimizer=Optimisers.Adam(1e-3), ) # Evaluate final ELBO with 10^3 Monte Carlo samples diff --git a/bench/README.md b/bench/README.md index 5c5214d8..8a8f5163 100644 --- a/bench/README.md +++ b/bench/README.md @@ -1,7 +1,5 @@ - # AdvancedVI.jl Continuous Benchmarking This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`. The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main). The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl). 
- diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 8585fa8c..551e12b2 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -33,25 +33,21 @@ const SUITES = BenchmarkGroup() # n_montecarlo = 4, # ) -SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = - @benchmarkable normallognormal( - ; - fptype = Float64, - adtype = AutoReverseDiff(), - family = :meanfield, - objective = :RepGradELBO, - n_montecarlo = 4, - ) - -SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = - @benchmarkable normallognormal( - ; - fptype = Float64, - adtype = AutoForwardDiff(), - family = :meanfield, - objective = :RepGradELBO, - n_montecarlo = 4, - ) +SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = @benchmarkable normallognormal(; + fptype=Float64, + adtype=AutoReverseDiff(), + family=:meanfield, + objective=:RepGradELBO, + n_montecarlo=4, +) + +SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = @benchmarkable normallognormal(; + fptype=Float64, + adtype=AutoForwardDiff(), + family=:meanfield, + objective=:RepGradELBO, + n_montecarlo=4, +) BenchmarkTools.tune!(SUITES; verbose=true) results = BenchmarkTools.run(SUITES; verbose=true) diff --git a/bench/normallognormal.jl b/bench/normallognormal.jl index 15d5a5a0..075bf3dc 100644 --- a/bench/normallognormal.jl +++ b/bench/normallognormal.jl @@ -8,49 +8,49 @@ end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) (; μ_x, σ_x, μ_y, Σ_y) = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) + return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 + return length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{0}() end function Bijectors.bijector(model::NormalLogNormal) (; μ_x, σ_x, μ_y, Σ_y) = model - Bijectors.Stacked( + return Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) + [1:1, 2:(1 + length(μ_y))], + ) end -function normallognormal(; fptype, adtype, family, objective, kwargs...) +function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...) n_dims = 10 - μ_x = fptype(5.0) - σ_x = fptype(0.3) - μ_y = Fill(fptype(5.0), n_dims) - σ_y = Fill(fptype(0.3), n_dims) - model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + μ_x = fptype(5.0) + σ_x = fptype(0.3) + μ_y = Fill(fptype(5.0), n_dims) + σ_y = Fill(fptype(0.3), n_dims) + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)) obj = variational_objective(objective; kwargs...) 
d = LogDensityProblems.dimension(model) - q = variational_standard_mvnormal(fptype, d, family) + q = variational_standard_mvnormal(fptype, d, family) - b = Bijectors.bijector(model) - binv = inverse(b) + b = Bijectors.bijector(model) + binv = inverse(b) q_transformed = Bijectors.TransformedDistribution(q, binv) - max_iter = 10^3 - AdvancedVI.optimize( + return AdvancedVI.optimize( model, obj, q_transformed, max_iter; adtype, - optimizer = Optimisers.Adam(fptype(1e-3)), - show_progress = false, + optimizer=Optimisers.Adam(fptype(1e-3)), + show_progress=false, ) end diff --git a/bench/utils.jl b/bench/utils.jl index 31c87c3d..d95741cd 100644 --- a/bench/utils.jl +++ b/bench/utils.jl @@ -1,13 +1,9 @@ function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol) if family == :meanfield - AdvancedVI.MeanFieldGaussian( - zeros(type, n_dims), Diagonal(ones(type, n_dims)) - ) + AdvancedVI.MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims))) else - AdvancedVI.FullRankGaussian( - zeros(type, n_dims), Matrix(type, I, n_dims, n_dims) - ) + AdvancedVI.FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims)) end end @@ -15,6 +11,10 @@ function variational_objective(objective::Symbol; kwargs...) if objective == :RepGradELBO AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]) elseif objective == :RepGradELBOSTL - AdvancedVI.RepGradELBO(kwargs[:n_montecarlo], entropy=StickingTheLandingEntropy()) + AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy()) + elseif objective == :ScoreGradELBO + throw("ScoreGradELBO not supported yet. Please use ScoreGradELBOSTL instead.") + elseif objective == :ScoreGradELBOSTL + AdvancedVI.ScoreGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy()) end end diff --git a/docs/src/elbo/overview.md b/docs/src/elbo/overview.md index 4afac4db..db9b598e 100644 --- a/docs/src/elbo/overview.md +++ b/docs/src/elbo/overview.md @@ -1,5 +1,5 @@ - # [Evidence Lower Bound Maximization](@id elbomax) + ## Introduction Evidence lower bound (ELBO) maximization[^JGJS1999] is a general family of algorithms that minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence between the target distribution ``\pi`` and a variational approximation ``q_{\lambda}``. @@ -8,15 +8,19 @@ More generally, they aim to solve the following problem: ```math \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right), ``` + where $$\mathcal{Q}$$ is some family of distributions, often called the variational family. Since the target distribution ``\pi`` is intractable in general, the KL divergence is also intractable. Instead, the ELBO maximization strategy maximizes a surrogate objective, the *ELBO*: + ```math \mathrm{ELBO}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right) + \mathbb{H}\left(q\right), ``` + which serves as a lower bound to the KL. The ELBO and its gradient can be readily estimated through various strategies. Overall, ELBO maximization algorithms aim to solve the problem: + ```math \mathrm{maximize}_{q \in \mathcal{Q}}\quad \mathrm{ELBO}\left(q\right). ``` @@ -24,13 +28,15 @@ Overall, ELBO maximization algorithms aim to solve the problem: Multiple ways to solve this problem exist, each leading to a different variational inference algorithm. ## Algorithms + Currently, `AdvancedVI` only provides the approach known as black-box variational inference (also known as Monte Carlo VI, Stochastic Gradient VI). 
(Introduced independently by two groups [^RGB2014][^TL2014] in 2014.)
In particular, `AdvancedVI` focuses on the reparameterization gradient estimator[^TL2014][^RMW2014][^KW2014], which is generally superior to alternative strategies[^XQKS2019], discussed in the following section:
-* [RepGradELBO](@ref repgradelbo)
+
+  - [RepGradELBO](@ref repgradelbo)

[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233.
-[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. 
+[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*.
[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*.
[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*.
[^XQKS2019]: Xu, M., Quiroz, M., Kohn, R., & Sisson, S. A. (2019). Variance reduction properties of the reparameterization trick. In *The International Conference on Artificial Intelligence and Statistics*.
diff --git a/docs/src/general.md b/docs/src/general.md
index 07a240e1..f4a8281d 100644
--- a/docs/src/general.md
+++ b/docs/src/general.md
@@ -1,13 +1,14 @@
-
 # [General Usage](@id general)
 
 Each VI algorithm provides the following:
-1. Variational families supported by each VI algorithm.
-2. A variational objective corresponding to the VI algorithm.
-Note that each variational family is subject to its own constraints.
-Thus, please refer to the documentation of the variational inference algorithm of interest.
+
+  1. Variational families supported by each VI algorithm.
+  2. A variational objective corresponding to the VI algorithm.
+     Note that each variational family is subject to its own constraints.
+     Thus, please refer to the documentation of the variational inference algorithm of interest.
 
 ## Optimizing a Variational Objective
+
 After constructing a *variational objective* `objective` and initializing a *variational approximation*, one can optimize `objective` by calling `optimize`:
 
 ```@docs
 optimize
 ```
 
 ## Estimating the Objective
+
 In some cases, it is useful to directly estimate the objective value.
 This can be done by the following function:
+
 ```@docs
 estimate_objective
 ```
 
 !!! info
-    Note that `estimate_objective` is not expected to be differentiated through, and may not result in optimal statistical performance.
+
+    Note that `estimate_objective` is not expected to be differentiated through, and may not result in optimal statistical performance.
 
 ## Advanced Usage
+
 Each variational objective is a subtype of the following abstract type:
+
 ```@docs
 AdvancedVI.AbstractVariationalObjective
 ```
 
 Furthermore, `AdvancedVI` only interacts with each variational objective by querying gradient estimates.
 Therefore, to create a new custom objective to be optimized through `AdvancedVI`, it suffices to implement the following function:
+
 ```@docs
 AdvancedVI.estimate_gradient!
 ```
 
 If an objective needs to be stateful, one can implement the following function to initialize the state.
+
```@docs
AdvancedVI.init
```
diff --git a/docs/src/index.md b/docs/src/index.md
index feb7adff..f177cd72 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -5,10 +5,13 @@ CurrentModule = AdvancedVI
 # AdvancedVI
 
 ## Introduction
+
 [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms.
 VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness.
 `AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.
 
 ## Provided Algorithms
+
 `AdvancedVI` currently provides the following algorithm for evidence lower bound maximization:
-- [Evidence Lower-Bound Maximization](@ref elbomax)
+
+  - [Evidence Lower-Bound Maximization](@ref elbomax)
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 5402e075..4c3c39cc 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -170,10 +170,12 @@ Estimate the entropy of `q`.
 """
 function estimate_entropy end
 
-export RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy
+export RepGradELBO,
+    ScoreGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy
 
 include("objectives/elbo/entropy.jl")
 include("objectives/elbo/repgradelbo.jl")
+include("objectives/elbo/scoregradelbo.jl")
 
 # Variational Families
 export MvLocationScale, MeanFieldGaussian, FullRankGaussian
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 210b49ca..fa34022a 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -37,3 +37,10 @@ function estimate_entropy(
        -logpdf(q, mc_sample)
    end
end
+
+function estimate_entropy_maybe_stl(
+    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
+)
+    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
+    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
+end
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index e6f04ae8..b8bf63fa 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -45,13 +45,6 @@ function Base.show(io::IO, obj::RepGradELBO)
    return print(io, ")")
end
 
-function estimate_entropy_maybe_stl(
-    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
-)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
-end
-
function estimate_energy_with_samples(prob, samples)
    return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end
diff --git a/src/objectives/elbo/scoregradelbo.jl b/src/objectives/elbo/scoregradelbo.jl
new file mode 100644
index 00000000..053c6b3f
--- /dev/null
+++ b/src/objectives/elbo/scoregradelbo.jl
@@ -0,0 +1,139 @@
+"""
+    ScoreGradELBO(n_samples; kwargs...)
+
+Evidence lower-bound objective computed with score function gradients.
+```math
+\\begin{aligned}
+\\nabla_{\\lambda} \\mathrm{ELBO}\\left(\\lambda\\right)
+&=
+\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
+  \\log \\pi\\left(z\\right) \\nabla_{\\lambda} \\log q_{\\lambda}(z)
+\\right]
++ \\nabla_{\\lambda} \\mathbb{H}\\left(q_{\\lambda}\\right),
+\\end{aligned}
+```
+
+To reduce the variance of the gradient estimator, we use a baseline computed from a running average of the previous ELBO values and subtract it from the objective.
+ +```math +\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[ + \\nabla_{\\lambda} \\log q_{\\lambda}(z) \\left(\\pi\\left(z\\right) - \\beta\\right) +\\right] +``` + +# Arguments +- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO. + +# Keyword Arguments +- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`) +- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`) +- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`) + +# Requirements +- The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`. +- `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend. +- The target distribution and the variational approximation have the same support. + +Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. +""" +struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <: + AdvancedVI.AbstractVariationalObjective + entropy::EntropyEst + n_samples::Int + baseline_window_size::Int + baseline_history::Vector{Float64} +end + +function ScoreGradELBO( + n_samples::Int; + entropy::AbstractEntropyEstimator=ClosedFormEntropy(), + baseline_window_size::Int=10, + baseline_history::Vector{Float64}=Float64[], +) + return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history) +end + +function Base.show(io::IO, obj::ScoreGradELBO) + print(io, "ScoreGradELBO(entropy=") + print(io, obj.entropy) + print(io, ", n_samples=") + print(io, obj.n_samples) + print(io, ", baseline_window_size=") + print(io, obj.baseline_window_size) + return print(io, ")") +end + +function compute_control_variate_baseline(history, window_size) + if length(history) == 0 + return 1.0 + end + min_index = max(1, length(history) - window_size) + return mean(history[min_index:end]) +end + +function estimate_energy_with_samples( + prob, samples_stop, samples_logprob, samples_logprob_stop, baseline +) + fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop)) + fv_mean = mean(fv) + score_grad = mean(@. samples_logprob * (fv - baseline)) + score_grad_stop = mean(@. 
samples_logprob_stop * (fv - baseline))
+    return fv_mean + (score_grad - score_grad_stop)
+end
+
+function estimate_objective(
+    rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
+)
+    samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy)
+    energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+    return mean(energy) + entropy
+end
+
+function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
+    return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
+end
+
+function estimate_scoregradelbo_ad_forward(params′, aux)
+    @unpack rng, obj, problem, adtype, restructure, q_stop = aux
+    baseline = compute_control_variate_baseline(
+        obj.baseline_history, obj.baseline_window_size
+    )
+    q = restructure_ad_forward(adtype, restructure, params′)
+    samples_stop = rand(rng, q_stop, obj.n_samples)
+    entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop)
+    samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
+    samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop))
+    energy = estimate_energy_with_samples(
+        problem, samples_stop, samples_logprob, samples_logprob_stop, baseline
+    )
+    elbo = energy + entropy
+    return -elbo
+end
+
+function AdvancedVI.estimate_gradient!(
+    rng::Random.AbstractRNG,
+    obj::ScoreGradELBO,
+    adtype::ADTypes.AbstractADType,
+    out::DiffResults.MutableDiffResult,
+    prob,
+    params,
+    restructure,
+    state,
+)
+    q_stop = restructure(params)
+    aux = (
+        rng=rng,
+        adtype=adtype,
+        obj=obj,
+        problem=prob,
+        restructure=restructure,
+        q_stop=q_stop,
+    )
+    AdvancedVI.value_and_gradient!(
+        adtype, estimate_scoregradelbo_ad_forward, params, aux, out
+    )
+    nelbo = DiffResults.value(out)
+    stat = (elbo=-nelbo,)
+    push!(obj.baseline_history, -nelbo)
+    return out, nothing, stat
+end
diff --git a/src/optimize.jl b/src/optimize.jl
index eb462ff5..9a748907 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -73,7 +73,6 @@ function optimize(
    for t in 1:max_iter
        stat = (iteration=t,)
-
        grad_buf, obj_st, stat′ = estimate_gradient!(
            rng,
            objective,
diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl
new file mode 100644
index 00000000..700dda6d
--- /dev/null
+++ b/test/inference/scoregradelbo_distributionsad.jl
@@ -0,0 +1,101 @@
+
+AD_distributionsad = Dict(
+    :ForwarDiff => AutoForwardDiff(),
+    #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
+    :Zygote => AutoZygote(),
+)
+
+if @isdefined(Tapir)
+    AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
+end
+
+#if @isdefined(Enzyme)
+#    AD_distributionsad[:Enzyme] = AutoEnzyme()
+#end
+
+@testset "inference ScoreGradELBO DistributionsAD" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
+                                                                     [Float64, Float32],
+        (modelname, modelconstr) in Dict(:Normal => normal_meanfield),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
+            :ScoreGradELBOStickingTheLanding =>
+                ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
+        ),
+        (adbackname, adtype) in AD_distributionsad
+
+        seed = (0x38bef07cf9cc549d)
+        rng = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
+
+        T = 1000
+        η = 1e-5
+        opt = Optimisers.Descent(realtype(η))
+
+        # For small enough η, the 
error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η * strong_convexity + + μ0 = zeros(realtype, n_dims) + L0 = Diagonal(ones(realtype, n_dims)) + q0 = TuringDiagMvNormal(μ0, diag(L0)) + + @testset "convergence" begin + Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + + μ = mean(q_avg) + L = sqrt(cov(q_avg)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + μ = mean(q_avg) + L = sqrt(cov(q_avg)) + + rng_repl = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng_repl, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + μ_repl = mean(q_avg) + L_repl = sqrt(cov(q_avg)) + @test μ ≈ μ_repl rtol = 1e-5 + @test L ≈ L_repl rtol = 1e-5 + end + end +end diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl new file mode 100644 index 00000000..ef49713b --- /dev/null +++ b/test/inference/scoregradelbo_locationscale.jl @@ -0,0 +1,105 @@ + +AD_locationscale = Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), +) + +if @isdefined(Tapir) + AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false) +end + +if @isdefined(Enzyme) + AD_locationscale[:Enzyme] = AutoEnzyme() +end + +@testset "inference ScoreGradELBO VILocationScale" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), + n_montecarlo in [1, 10], + (objname, objective) in Dict( + :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), + :ScoreGradELBOStickingTheLanding => + ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), + ), + (adbackname, adtype) in AD_locationscale + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + + T = 1000 + η = 1e-5 + opt = Optimisers.Descent(realtype(η)) + + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. 
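+        # The check below uses the looser criterion Δλ ≤ ρ^(T/2) Δλ0; the exponent
+        # T/2 rather than T presumably leaves slack for the O(η) term and for the
+        # Monte Carlo noise of the score-gradient estimator.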
+ contraction_rate = 1 - η * strong_convexity + + q0 = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + FullRankGaussian(zeros(realtype, n_dims), L0) + end + + @testset "convergence" begin + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + + μ = q_avg.location + L = q_avg.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng_repl, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ ≈ μ_repl rtol = 1e-3 + @test L ≈ L_repl rtol = 1e-3 + end + end +end diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl new file mode 100644 index 00000000..088130aa --- /dev/null +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -0,0 +1,111 @@ + +AD_locationscale_bijectors = Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), +) + +#if @isdefined(Tapir) +# AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) +#end + +if @isdefined(Enzyme) + AD_locationscale_bijectors[:Enzyme] = AutoEnzyme() +end + +@testset "inference ScoreGradELBO VILocationScale Bijectors" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:NormalLogNormalMeanField => normallognormal_meanfield), + n_montecarlo in [1, 10], + (objname, objective) in Dict( + #:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet. + :ScoreGradELBOStickingTheLanding => + ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), + ), + (adbackname, adtype) in AD_locationscale_bijectors + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats + + T = 1000 + η = 1e-5 + opt = Optimisers.Descent(realtype(η)) + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) + + q0_η = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + FullRankGaussian(zeros(realtype, n_dims), L0) + end + q0_z = Bijectors.transformed(q0_η, b⁻¹) + + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. 
+ contraction_rate = 1 - η * strong_convexity + + @testset "convergence" begin + Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + + μ = q_avg.dist.location + L = q_avg.dist.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + μ = q_avg.dist.location + L = q_avg.dist.scale + + rng_repl = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng_repl, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, + ) + μ_repl = q_avg.dist.location + L_repl = q_avg.dist.scale + @test μ ≈ μ_repl rtol = 1e-3 + @test L ≈ L_repl rtol = 1e-3 + end + end +end diff --git a/test/interface/scoregradelbo.jl b/test/interface/scoregradelbo.jl new file mode 100644 index 00000000..a800f744 --- /dev/null +++ b/test/interface/scoregradelbo.jl @@ -0,0 +1,57 @@ + +using Test + +@testset "interface ScoreGradELBO" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + + obj = ScoreGradELBO(10) + rng = StableRNG(seed) + elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4) + + @testset "determinism" begin + rng = StableRNG(seed) + elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4) + @test elbo == elbo_ref + end + + @testset "default_rng" begin + elbo = estimate_objective(obj, q0, model; n_samples=10^4) + @test elbo ≈ elbo_ref rtol = 0.2 + end +end + +@testset "interface ScoreGradELBO STL variance reduction" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + @testset for ad in [ + ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() + ] + q_true = MeanFieldGaussian( + Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) + ) + params, re = Optimisers.destructure(q_true) + obj = ScoreGradELBO( + 1000; entropy=StickingTheLandingEntropy(), baseline_history=[0.0] + ) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) + AdvancedVI.value_and_gradient!( + ad, AdvancedVI.estimate_scoregradelbo_ad_forward, params, aux, out + ) + value = DiffResults.value(out) + grad = DiffResults.gradient(out) + @test norm(grad) ≈ 0 atol = 10 # high tolerance required. 
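+        # q_true is built from the target's exact mean and scale, so the ELBO is
+        # maximized there and the STL-style score gradient should be close to zero
+        # up to Monte Carlo noise, which is presumably why the loose `atol` suffices.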
+ end +end diff --git a/test/runtests.jl b/test/runtests.jl index 43958e8e..85bec3a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,6 +50,7 @@ if GROUP == "All" || GROUP == "Interface" include("interface/ad.jl") include("interface/optimize.jl") include("interface/repgradelbo.jl") + include("interface/scoregradelbo.jl") include("interface/rules.jl") include("interface/averaging.jl") end @@ -65,4 +66,7 @@ if GROUP == "All" || GROUP == "Inference" include("inference/repgradelbo_distributionsad.jl") include("inference/repgradelbo_locationscale.jl") include("inference/repgradelbo_locationscale_bijectors.jl") + include("inference/scoregradelbo_distributionsad.jl") + include("inference/scoregradelbo_locationscale.jl") + include("inference/scoregradelbo_locationscale_bijectors.jl") end
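
For reference, below is a minimal usage sketch of the new objective, mirroring the README example above. It assumes the `NormalLogNormal` `model` and its `Bijectors.bijector` from that example are already defined; the `ScoreGradELBO` constructor and keyword arguments follow the docstring added in `src/objectives/elbo/scoregradelbo.jl`, and the sticking-the-landing entropy estimator is used because the closed-form entropy is not yet supported for bijector-transformed families (see the commented-out case in `test/inference/scoregradelbo_locationscale_bijectors.jl`).

```julia
using ADTypes, ForwardDiff, LinearAlgebra, Optimisers
using AdvancedVI, Bijectors, LogDensityProblems

# Score-gradient ELBO; STL entropy is required for the transformed family here.
n_montecarlo = 10
elbo = AdvancedVI.ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy())

# Mean-field Gaussian in unconstrained space, pushed to the model's support.
d = LogDensityProblems.dimension(model)
q = AdvancedVI.MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
q_transformed = Bijectors.TransformedDistribution(q, inverse(Bijectors.bijector(model)))

# Same call as the README example; only the objective differs.
max_iter = 10^3
q_avg, _, stats, _ = AdvancedVI.optimize(
    model,
    elbo,
    q_transformed,
    max_iter;
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=Optimisers.Adam(1e-3),
)
```

Compared to `RepGradELBO`, the score-gradient estimator typically has higher variance, so the number of Monte Carlo samples and the step size may need tuning.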