From 48607c5c6a845049c0342e3d642b209af49093a6 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 7 Jun 2024 01:28:12 +0100 Subject: [PATCH 1/4] add indirection for update step, add projection for `LocationScale` --- src/AdvancedVI.jl | 27 ++++++++++++++++ src/families/location_scale.jl | 58 +++++++++++++++++++++++----------- src/optimize.jl | 4 ++- 3 files changed, 69 insertions(+), 20 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9fc986d3..7c7a1fc8 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -37,6 +37,33 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di """ function value_and_gradient! end +# Update for gradient descent step +""" + update_variational_params!(family_type, opt_st, params, restructure, grad) + +Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`. + +This is a wrapper around `Optimisers.update!` to provide some indirection. +For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings. +Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this function. +Instead, the return values should be used. + +# Arguments +- `family_type::Type`: Type of the variational family `typeof(restructure(params))`. +- `opt_st`: Optimizer state returned by `Optimisers.setup`. +- `params`: Current set of parameters to be updated. +- `restructure`: Callable for restructuring the variational distribution from `params`. +- `grad`: Gradient to be used by the update rule of `opt_st`. + +# Returns +- `opt_st`: Updated optimizer state. +- `params`: Updated parameters. +""" +function update_variational_params! 
end + +update_variational_params!(::Type, opt_st, params, restructure, grad) = + Optimisers.update!(opt_st, params, grad) + # estimators """ AbstractVariationalObjective diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index e60538a1..5c3b0eda 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -14,11 +14,21 @@ represented as follows: ``` """ struct MvLocationScale{ - S, D <: ContinuousDistribution, L + S, D <: ContinuousDistribution, L, E <: Real } <: ContinuousMultivariateDistribution - location::L - scale ::S - dist ::D + location ::L + scale ::S + dist ::D + scale_eps::E +end + +function MvLocationScale( + location ::AbstractVector{T}, + scale ::AbstractMatrix{T}, + dist ::ContinuousDistribution; + scale_eps::T = sqrt(eps(T)) +) where {T <: Real} + MvLocationScale(location, scale, dist, scale_eps) end Functors.@functor MvLocationScale (location, scale) @@ -57,17 +67,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D) function StatsBase.entropy(q::MvLocationScale) @unpack location, scale, dist = q n_dims = length(location) - n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale)) + n_dims*convert(eltype(location), entropy(dist)) + logdet(scale) end function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) end function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) end function Distributions.rand(q::MvLocationScale) @@ -128,14 +138,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix. 
function FullRankGaussian( μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; - check_args::Bool = true + scale_eps::T = sqrt(eps(T)) ) where {T <: Real} - @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" - if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) - @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." - end + @assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base) + MvLocationScale(μ, L, q_base, scale_eps) end """ @@ -153,12 +160,25 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix function MeanFieldGaussian( μ::AbstractVector{T}, L::Diagonal{T}; - check_args::Bool = true + scale_eps::T = sqrt(eps(T)), ) where {T <: Real} - @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor" - if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) - @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." - end + @assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base) + MvLocationScale(μ, L, q_base, scale_eps) +end + +function update_variational_params!( + ::MvLocationScale, opt_st, params, restructure, grad +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = q.scale_eps + + # Project the scale matrix to the set of positive definite triangular matrices + diag_idx = diagind(q.scale) + @. 
q.scale[diag_idx] = max(q.scale[diag_idx], ϵ) + + params, _ = Optimisers.destructure(q) + + opt_st, params end diff --git a/src/optimize.jl b/src/optimize.jl index 325de5a2..4eb6644a 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -80,7 +80,9 @@ function optimize( stat = merge(stat, stat′) grad = DiffResults.gradient(grad_buf) - opt_st, params = Optimisers.update!(opt_st, params, grad) + opt_st, params = update_variational_params!( + typeof(q_init), opt_st, params, restructure, grad + ) if !isnothing(callback) stat′ = callback( From a54c7fc96adeff666523d92e62d4cd609b1e914a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 7 Jun 2024 02:13:34 +0100 Subject: [PATCH 2/4] add projection for `Bijectors` with `MvLocationScale` --- ext/AdvancedVIBijectorsExt.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 29a877fd..4a88d6fb 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -4,13 +4,37 @@ module AdvancedVIBijectorsExt if isdefined(Base, :get_extension) using AdvancedVI using Bijectors + using LinearAlgebra + using Optimisers using Random else using ..AdvancedVI using ..Bijectors + using ..LinearAlgebra + using ..Optimisers using ..Random end +function AdvancedVI.update_variational_params!( + ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, + opt_st, + params, + restructure, + grad +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = q.dist.scale_eps + + # Project the scale matrix to the set of positive definite triangular matrices + diag_idx = diagind(q.dist.scale) + @. 
q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ) + + params, _ = Optimisers.destructure(q) + + opt_st, params +end + function AdvancedVI.reparam_with_entropy( rng ::Random.AbstractRNG, q ::Bijectors.TransformedDistribution, From 538bafe1722ac4852541bc9786c1fc7f275c6460 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 12 Jun 2024 23:53:50 +0100 Subject: [PATCH 3/4] fix signature for projection of `LocationScale` families --- src/families/location_scale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 5c3b0eda..1e96576f 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -168,7 +168,7 @@ function MeanFieldGaussian( end function update_variational_params!( - ::MvLocationScale, opt_st, params, restructure, grad + ::Type{<:MvLocationScale}, opt_st, params, restructure, grad ) opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) From b2b09bb82b485fa127c27ae0e94eb02077e7c831 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 00:30:40 +0100 Subject: [PATCH 4/4] add tests for scale projection operators, fix bug for meanfield --- src/families/location_scale.jl | 4 ++-- test/interface/location_scale.jl | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 1e96576f..e27a875b 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -46,14 +46,14 @@ function (re::RestructureMeanField)(flat::AbstractVector) n_dims = div(length(flat), 2) location = first(flat, n_dims) scale = Diagonal(last(flat, n_dims)) - MvLocationScale(location, scale, re.q.dist) + MvLocationScale(location, scale, re.q.dist, re.q.scale_eps) end function Optimisers.destructure( q::MvLocationScale{<:Diagonal, D, L} ) where {D, L} @unpack location, scale, dist = q - flat = vcat(location, diag(scale)) + 
flat = vcat(location, diag(scale)) flat, RestructureMeanField(q) end # end diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl index 6670f5c2..94e51fa6 100644 --- a/test/interface/location_scale.jl +++ b/test/interface/location_scale.jl @@ -138,3 +138,35 @@ @test q == re(λ) end end + +@testset "scale positive definite projection" begin + @testset "$(string(covtype)) $(realtype) $(bijector)" for + covtype = [:meanfield, :fullrank], + realtype = [Float32, Float64], + bijector = [nothing, :identity] + + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + q = if covtype == :fullrank + L = LowerTriangular(Matrix{realtype}(I,d,d)) + FullRankGaussian(μ, L; scale_eps=ϵ) + elseif covtype == :meanfield + L = Diagonal(ones(realtype, d)) + MeanFieldGaussian(μ, L; scale_eps=ϵ) + end + q_trans = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, identity) + end + g = deepcopy(q) + + λ, re = Optimisers.destructure(q) + grad, _ = Optimisers.destructure(g) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) + @test all(diag(var(q′)) .≥ ϵ^2) + end +end