diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md
index 59d5dc96..ccc69a97 100644
--- a/docs/src/elbo/repgradelbo.md
+++ b/docs/src/elbo/repgradelbo.md
@@ -218,7 +218,8 @@ _, _, stats_cfe, _ = AdvancedVI.optimize(
     max_iter;
     show_progress = false,
     adtype = AutoForwardDiff(),
-    optimizer = ProjectScale(Optimisers.Adam(3e-3)),
+    optimizer = Optimisers.Adam(3e-3),
+    operator = ClipScale(),
     callback = callback,
 );
 
@@ -229,7 +230,8 @@ _, _, stats_stl, _ = AdvancedVI.optimize(
     max_iter;
     show_progress = false,
     adtype = AutoForwardDiff(),
-    optimizer = ProjectScale(Optimisers.Adam(3e-3)),
+    optimizer = Optimisers.Adam(3e-3),
+    operator = ClipScale(),
     callback = callback,
 );
 
@@ -316,7 +318,8 @@ _, _, stats_qmc, _ = AdvancedVI.optimize(
     max_iter;
     show_progress = false,
     adtype = AutoForwardDiff(),
-    optimizer = ProjectScale(Optimisers.Adam(3e-3)),
+    optimizer = Optimisers.Adam(3e-3),
+    operator = ClipScale(),
     callback = callback,
 );
 
diff --git a/docs/src/examples.md b/docs/src/examples.md
index 9ecd3d26..9c0292b8 100644
--- a/docs/src/examples.md
+++ b/docs/src/examples.md
@@ -117,13 +117,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
     n_max_iter;
     show_progress=false,
     adtype=AutoForwardDiff(),
-    optimizer=ProjectScale(Optimisers.Adam(1e-3)),
+    optimizer=Optimisers.Adam(1e-3),
+    operator=ClipScale(),
 );
 nothing
 ```
 
-`ProjectScale` is a wrapper around an optimization rule such that the variational approximation stays within a stable region of the variational family.
-For more information see [this section](@ref projectscale).
+`ClipScale` is a projection operator that ensures the variational approximation stays within a stable region of the variational family.
+For more information, see [this section](@ref clipscale).
 
 `q_avg_trans` is the final output of the optimization procedure.
 If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` will be the output of the averaging strategy, while `q_trans` is the last iterate.
diff --git a/docs/src/families.md b/docs/src/families.md
index 1c6f6472..e270acad 100644
--- a/docs/src/families.md
+++ b/docs/src/families.md
@@ -56,16 +56,6 @@
 FullRankGaussian
 MeanFieldGaussian
 ```
-### [Scale Projection Operator](@id projectscale)
-
-For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
-To ensure this, we provide the following wrapper around optimization rule:
-
-```@docs
-ProjectScale
-```
-
-[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
 ### Gaussian Variational Families
 
 ```julia
diff --git a/docs/src/optimization.md b/docs/src/optimization.md
index 05fe035d..af2b99fd 100644
--- a/docs/src/optimization.md
+++ b/docs/src/optimization.md
@@ -26,3 +26,19 @@ PolynomialAveraging
 ```
 
 [^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
 [^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
 [^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.
+
+## Operators
+
+Depending on the variational family, variational objective, and optimization strategy, it may be necessary to modify the variational parameters after each gradient-based update.
+For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `optimize`.
+
+### [`ClipScale`](@id clipscale)
+
+For location-scale families, optimization is often stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
+To ensure this, we provide the following projection operator:
+
+```@docs
+ClipScale
+```
+
+[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index acadbdf9..6255392b 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -232,6 +232,11 @@ Apply operator `op` on the variational parameters `params`. For instance, `op` c
 """
 function operate end
 
+"""
+    IdentityOperator()
+
+Identity operator, which leaves the variational parameters unchanged.
+"""
 struct IdentityOperator <: AbstractOperator end
 
 operate(::IdentityOperator, family, params, restructure) = params
diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl
index c2369df2..a51bc292 100644
--- a/src/optimization/clip_scale.jl
+++ b/src/optimization/clip_scale.jl
@@ -2,7 +2,7 @@
 """
     ClipScale(ϵ = 1e-5)
 
-Apply a projection ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`.
+Projection operator ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`.
 `ClipScale` also supports operating on `MvLocationScale` and `MvLocationScaleLowRank` wrapped by a `Bijectors.TransformedDistribution` object.
 """
 Optimisers.@def struct ClipScale <: AbstractOperator
diff --git a/src/optimize.jl b/src/optimize.jl
index 272c80d2..93e2afff 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -17,6 +17,7 @@ This requires the variational approximation to be marked as a functor through `F
 - `adtype::ADtypes.AbstractADType`: Automatic differentiation backend.
 - `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.)
 - `averager::AbstractAverager` : Parameter averaging strategy. (Default: `NoAveraging()`)
+- `operator::AbstractOperator` : Operator applied to the parameters after each optimization step. (Default: `IdentityOperator()`)
 - `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.)
 - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
 - `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
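
Taken together, these changes mean the projection is now requested through the `operator` keyword rather than by wrapping the optimizer in `ProjectScale`. Below is a minimal end-to-end sketch of the new calling convention. The toy log-density, dimension, iteration count, and step size are placeholders, and the positional arguments of `optimize` follow the package's documented examples:

```julia
using ADTypes, AdvancedVI, LinearAlgebra, Optimisers
using ForwardDiff
using LogDensityProblems

# Placeholder target: a standard Gaussian in 2 dimensions.
struct ToyGaussian end
LogDensityProblems.logdensity(::ToyGaussian, θ) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(::ToyGaussian) = 2
function LogDensityProblems.capabilities(::Type{ToyGaussian})
    return LogDensityProblems.LogDensityOrder{0}()
end

# Mean-field Gaussian initialization with unit diagonal scale.
q0 = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))

q_avg, q_last, stats, _ = AdvancedVI.optimize(
    ToyGaussian(),
    RepGradELBO(10),
    q0,
    10_000;
    show_progress = false,
    adtype = AutoForwardDiff(),
    optimizer = Optimisers.Adam(3e-3),  # plain rule, no ProjectScale wrapper
    operator = ClipScale(),             # projection supplied separately
);
```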
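The effect of `ClipScale` is easiest to picture on a diagonal scale matrix, whose eigenvalues are just its diagonal entries. The following is an illustration of the projection idea only, not the package's internal code:

```julia
using LinearAlgebra

# Clip the eigenvalues of a diagonal scale matrix from below at ϵ, so the
# smallest eigenvalue stays strictly positive after a gradient update.
clip_diagonal_scale(L::Diagonal, ϵ::Real=1e-5) = Diagonal(max.(diag(L), ϵ))

L = Diagonal([0.8, 2e-6, -3e-7])  # a gradient step pushed two entries too low
clip_diagonal_scale(L)            # Diagonal([0.8, 1e-5, 1e-5])
```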
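Since `operate` dispatches on the operator type, the `IdentityOperator` definition in the diff suggests how custom operators can be written: subtype `AbstractOperator` and add an `operate` method with the same signature. The operator below (`ClampParams` and its bound are hypothetical, introduced here only for illustration):

```julia
using AdvancedVI

# Hypothetical operator that clamps every variational parameter into
# [-bound, bound] after each optimization step.
struct ClampParams{T<:Real} <: AdvancedVI.AbstractOperator
    bound::T
end

function AdvancedVI.operate(op::ClampParams, family, params, restructure)
    return clamp.(params, -op.bound, op.bound)
end

# Passed like any other operator:
# AdvancedVI.optimize(...; operator = ClampParams(10.0))
```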