diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 29a877fd..4a88d6fb 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -4,13 +4,37 @@ module AdvancedVIBijectorsExt if isdefined(Base, :get_extension) using AdvancedVI using Bijectors + using LinearAlgebra + using Optimisers using Random else using ..AdvancedVI using ..Bijectors + using ..LinearAlgebra + using ..Optimisers using ..Random end +function AdvancedVI.update_variational_params!( + ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, + opt_st, + params, + restructure, + grad +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = q.dist.scale_eps + + # Project the scale onto positive definite triangular matrices by clamping its diagonal to at least ϵ + diag_idx = diagind(q.dist.scale) + @. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ) + + params, _ = Optimisers.destructure(q) + + opt_st, params +end + function AdvancedVI.reparam_with_entropy( rng ::Random.AbstractRNG, q ::Bijectors.TransformedDistribution, diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9fc986d3..7c7a1fc8 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -37,6 +37,33 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di """ function value_and_gradient! end +# Update for gradient descent step +""" + update_variational_params!(family_type, opt_st, params, restructure, grad) + +Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`. + +This is a wrapper around `Optimisers.update!` to provide some indirection. +For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings. +Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this function. 
+Instead, the return values should be used. + +# Arguments +- `family_type::Type`: Type of the variational family `typeof(restructure(params))`. +- `opt_st`: Optimizer state returned by `Optimisers.setup`. +- `params`: Current set of parameters to be updated. +- `restructure`: Callable for restructuring the variational distribution from `params`. +- `grad`: Gradient to be used by the update rule of `opt_st`. + +# Returns +- `opt_st`: Updated optimizer state. +- `params`: Updated parameters. +""" +function update_variational_params! end + +update_variational_params!(::Type, opt_st, params, restructure, grad) = + Optimisers.update!(opt_st, params, grad) + # estimators """ AbstractVariationalObjective diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index e60538a1..e27a875b 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -14,11 +14,21 @@ represented as follows: ``` """ struct MvLocationScale{ - S, D <: ContinuousDistribution, L + S, D <: ContinuousDistribution, L, E <: Real } <: ContinuousMultivariateDistribution - location::L - scale ::S - dist ::D + location ::L + scale ::S + dist ::D + scale_eps::E +end + +function MvLocationScale( + location ::AbstractVector{T}, + scale ::AbstractMatrix{T}, + dist ::ContinuousDistribution; + scale_eps::T = sqrt(eps(T)) +) where {T <: Real} + MvLocationScale(location, scale, dist, scale_eps) end Functors.@functor MvLocationScale (location, scale) @@ -36,14 +46,14 @@ function (re::RestructureMeanField)(flat::AbstractVector) n_dims = div(length(flat), 2) location = first(flat, n_dims) scale = Diagonal(last(flat, n_dims)) - MvLocationScale(location, scale, re.q.dist) + MvLocationScale(location, scale, re.q.dist, re.q.scale_eps) end function Optimisers.destructure( q::MvLocationScale{<:Diagonal, D, L} ) where {D, L} @unpack location, scale, dist = q - flat = vcat(location, diag(scale)) + flat = vcat(location, diag(scale)) flat, RestructureMeanField(q) end # end @@ -57,17 
+67,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D) function StatsBase.entropy(q::MvLocationScale) @unpack location, scale, dist = q n_dims = length(location) - n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale)) + n_dims*convert(eltype(location), entropy(dist)) + logdet(scale) end function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) end function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) end function Distributions.rand(q::MvLocationScale) @@ -128,14 +138,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix. function FullRankGaussian( μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; - check_args::Bool = true + scale_eps::T = sqrt(eps(T)) ) where {T <: Real} - @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" - if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) - @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." - end + @assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." 
+ q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base) + MvLocationScale(μ, L, q_base, scale_eps) end """ @@ -153,12 +160,25 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix function MeanFieldGaussian( μ::AbstractVector{T}, L::Diagonal{T}; - check_args::Bool = true + scale_eps::T = sqrt(eps(T)), ) where {T <: Real} - @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor" - if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) - @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." - end + @assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base) + MvLocationScale(μ, L, q_base, scale_eps) +end + +function update_variational_params!( + ::Type{<:MvLocationScale}, opt_st, params, restructure, grad +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = q.scale_eps + + # Project the scale onto positive definite triangular matrices by clamping its diagonal to at least ϵ + diag_idx = diagind(q.scale) + @. 
q.scale[diag_idx] = max(q.scale[diag_idx], ϵ) + + params, _ = Optimisers.destructure(q) + + opt_st, params end diff --git a/src/optimize.jl b/src/optimize.jl index 325de5a2..4eb6644a 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -80,7 +80,9 @@ function optimize( stat = merge(stat, stat′) grad = DiffResults.gradient(grad_buf) - opt_st, params = Optimisers.update!(opt_st, params, grad) + opt_st, params = update_variational_params!( + typeof(q_init), opt_st, params, restructure, grad + ) if !isnothing(callback) stat′ = callback( diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl index 6670f5c2..94e51fa6 100644 --- a/test/interface/location_scale.jl +++ b/test/interface/location_scale.jl @@ -138,3 +138,35 @@ @test q == re(λ) end end + +@testset "scale positive definite projection" begin + @testset "$(string(covtype)) $(realtype) $(bijector)" for + covtype = [:meanfield, :fullrank], + realtype = [Float32, Float64], + bijector = [nothing, :identity] + + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + q = if covtype == :fullrank + L = LowerTriangular(Matrix{realtype}(I,d,d)) + FullRankGaussian(μ, L; scale_eps=ϵ) + elseif covtype == :meanfield + L = Diagonal(ones(realtype, d)) + MeanFieldGaussian(μ, L; scale_eps=ϵ) + end + q_trans = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, identity) + end + g = deepcopy(q) + + λ, re = Optimisers.destructure(q) + grad, _ = Optimisers.destructure(g) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) + @test all(diag(var(q′)) .≥ ϵ^2) + end +end # NOTE(review): `q_trans` is assigned in this testset but never used; destructure/update run on the untransformed `q`, so the `bijector = :identity` case appears to repeat the same path — confirm the Bijectors method is actually exercised.