
Add "low-rank" variational families #76

Merged: 46 commits, Sep 13, 2024
Changes from 15 commits

Commits (46)
03563ea
rename location scale source file
Red-Portal Aug 3, 2024
5ab7286
revert renaming of location_scale file
Red-Portal Aug 3, 2024
3e0bf3d
add location-low-rank-scale family (except `entropy` and `logpdf`)
Red-Portal Aug 3, 2024
0bd6e5c
add feature complete `MvLocationScaleLowRank` with tests
Red-Portal Aug 5, 2024
34546e1
fix remove misleading comment
Red-Portal Aug 5, 2024
e030f2d
fix add missing test files
Red-Portal Aug 5, 2024
c7f36d6
fix broadcasting error on Julia 1.6
Red-Portal Aug 5, 2024
1bb3e3e
fix bug in sampling from `LocationScaleLowRank`
Red-Portal Aug 7, 2024
ddd2122
fix missing squared bug in `LocationScaleLowRank`
Red-Portal Aug 7, 2024
b24737f
add documentation for low-rank families
Red-Portal Aug 9, 2024
1d56953
add convenience constructors for `LocationScaleLowRank`
Red-Portal Aug 9, 2024
6752c6b
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 10, 2024
52568b5
fix mhauru's suggestions and run formatter
Red-Portal Aug 10, 2024
96eae86
run formatter
Red-Portal Aug 10, 2024
15556da
run formatter
Red-Portal Aug 10, 2024
f796154
fix bugs and improve comments in `MvLocationScale` and lowrank
Red-Portal Aug 11, 2024
6b1699c
promote families.md into a higher category
Red-Portal Aug 11, 2024
5187d76
add test for `MVLocationScale` with non-Gaussian
Red-Portal Aug 14, 2024
8821908
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 27, 2024
6dfc919
tighten compat bound for `Distributions`
Red-Portal Aug 27, 2024
c3ce393
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 4, 2024
5c04d50
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 5, 2024
ba293e5
fix base distribution standardization bug in `LocationScale`
Red-Portal Sep 5, 2024
426d943
fix base distribution standardization bug in `LocationScaleLowRank`
Red-Portal Sep 5, 2024
3cc9e80
format weird indentation in test `for` loops
Red-Portal Sep 5, 2024
0481dda
update docs add example for `LocationScaleLowRank`
Red-Portal Sep 5, 2024
8449402
fix docs warn about divergence when using `MvLocationScaleLowRank`
Red-Portal Sep 6, 2024
ff14c4c
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 9, 2024
e48f231
Merge branch 'master' into lowrank
yebai Sep 10, 2024
aa8feee
Merge branch 'master' into lowrank
yebai Sep 10, 2024
5149869
Merge branch 'master' into lowrank
yebai Sep 10, 2024
e196da6
Update Benchmark.yml
yebai Sep 10, 2024
e4bff67
disable more features for PRs from forks
yebai Sep 10, 2024
894a849
fix `LocationScale` interfaces to only allow univariate base dist
Red-Portal Sep 11, 2024
f1cabba
Merge branch 'lowrank' of github.com:Red-Portal/AdvancedVI.jl into lo…
Red-Portal Sep 11, 2024
ce6793c
fix test comparison operator for families
Red-Portal Sep 11, 2024
71aeb5a
fix test comparison operator for families
Red-Portal Sep 11, 2024
77ace2b
fix test comparison operator for families
Red-Portal Sep 11, 2024
641de39
fix test comparison operator for families
Red-Portal Sep 11, 2024
a58f209
fix test comparison operator for families
Red-Portal Sep 11, 2024
846b259
fix test comparison operator for families
Red-Portal Sep 11, 2024
1116f68
fix test comparison operator for families
Red-Portal Sep 11, 2024
42d730d
fix formatting
Red-Portal Sep 11, 2024
99d08c5
fix formatting
Red-Portal Sep 11, 2024
4a90c5d
fix scale lower bound to `1e-4`
Red-Portal Sep 12, 2024
c41709b
fix docstring for `LowRankGaussian`
Red-Portal Sep 12, 2024
1 change: 1 addition & 0 deletions .github/workflows/Benchmark.yml
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ concurrency:
permissions:
contents: write
pull-requests: write
issues: write

jobs:
benchmark:
2 changes: 1 addition & 1 deletion Project.toml
@@ -42,7 +42,7 @@ Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
Distributions = "0.25.87"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.12.32"
FillArrays = "1.3"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -18,6 +18,7 @@ makedocs(;
"Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
],
"Variational Families" => "families.md",
"Optimization" => "optimization.md",
],
)

128 changes: 125 additions & 3 deletions docs/src/families.md
@@ -3,7 +3,7 @@
The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path.
We provide multiple pre-packaged variational families that can be readily used.

## The `LocationScale` Family
## [The `LocationScale` Family](@id locscale)

The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as

@@ -38,6 +38,8 @@ where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.
Notice that ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``.
The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution.
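As a quick numerical sanity check of this identity (an editorial NumPy sketch; the package itself is Julia), the entropy of a Gaussian location-scale member with standard normal base factors as ``d\,\mathbb{H}(\varphi) + \log |C|``:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
# Lower-triangular scale with strictly positive diagonal
C = np.tril(rng.normal(size=(d, d)))
np.fill_diagonal(C, np.abs(np.diag(C)) + 0.1)

# Entropy of N(mu, C C^T) computed from the dense covariance ...
Sigma = C @ C.T
H_direct = 0.5 * d * np.log(2 * np.pi * np.e) + 0.5 * np.linalg.slogdet(Sigma)[1]

# ... and via H(q) = d * H(phi) + log|C|, with phi = N(0, 1)
H_phi = 0.5 * np.log(2 * np.pi * np.e)  # entropy of a standard normal
H_family = d * H_phi + np.sum(np.log(np.diag(C)))

assert np.isclose(H_direct, H_family)
```

Since ``\log|C|`` carries all of the ``\lambda``-dependence of the entropy here, the base-distribution term drops out of the gradient, as the text states.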

### API

!!! note

For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned.
@@ -128,14 +130,134 @@ and the entropy is given by the matrix determinant lemma as

where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.
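The matrix determinant lemma can be verified numerically. The following NumPy sketch (illustration only) compares the dense ``O\left(d^3\right)`` log-determinant of ``D^2 + U U^{\top}`` against the ``O\left(d r^2 + r^3\right)`` form:

```python
import numpy as np

rng = np.random.default_rng(1)
d, r = 30, 3
D = np.abs(rng.normal(size=d)) + 0.1   # diagonal scale entries
U = rng.normal(size=(d, r))            # low-rank scale factors

# Dense evaluation: O(d^3)
Sigma = np.diag(D**2) + U @ U.T
logdet_dense = np.linalg.slogdet(Sigma)[1]

# Matrix determinant lemma: logdet(D^2) + logdet(I_r + U^T D^{-2} U)
inner = np.eye(r) + (U.T / D**2) @ U
logdet_lemma = 2 * np.sum(np.log(D)) + np.linalg.slogdet(inner)[1]

assert np.isclose(logdet_dense, logdet_lemma)
```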

!!! note
```@setup lowrank
using ADTypes
using AdvancedVI
using Distributions
using LinearAlgebra
using LogDensityProblems
using Optimisers
using Plots
using ReverseDiff

struct Target{D}
dist::D
end

function LogDensityProblems.logdensity(model::Target, θ)
logpdf(model.dist, θ)
end

function LogDensityProblems.dimension(model::Target)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:Target})
return LogDensityProblems.LogDensityOrder{0}()
end

n_dims = 30
U_true = randn(n_dims, 3)
D_true = Diagonal(log.(1 .+ exp.(randn(n_dims))))
Σ_true = D_true + U_true*U_true'
Σsqrt_true = sqrt(Σ_true)
μ_true = randn(n_dims)
model = Target(MvNormal(μ_true, Σ_true));

d = LogDensityProblems.dimension(model);
μ = zeros(d);

L = Diagonal(ones(d));
q0_mf = MeanFieldGaussian(μ, L)

L = LowerTriangular(diagm(ones(d)));
q0_fr = FullRankGaussian(μ, L)

D = ones(n_dims)
U = zeros(n_dims, 3)
q0_lr = LowRankGaussian(μ, D, U)

obj = RepGradELBO(1);

max_iter = 10^4

function callback(; params, averaged_params, restructure, stat, kwargs...)
q = restructure(averaged_params)
μ, Σ = mean(q), cov(q)
(dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true)),)
end

_, _, stats_fr, _ = AdvancedVI.optimize(
model,
obj,
q0_fr,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);

_, _, stats_mf, _ = AdvancedVI.optimize(
model,
obj,
q0_mf,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);

_, _, stats_lr, _ = AdvancedVI.optimize(
model,
obj,
q0_lr,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);

t = [stat.iteration for stat in stats_fr]
dist_fr = [sqrt(stat.dist2) for stat in stats_fr]
dist_mf = [sqrt(stat.dist2) for stat in stats_mf]
dist_lr = [sqrt(stat.dist2) for stat in stats_lr]
plot( t, dist_mf , label="Mean-Field Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
plot!(t, dist_fr, label="Full-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
plot!(t, dist_lr, label="Low-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
savefig("lowrank_family_wasserstein.svg")
nothing
```

Consider a 30-dimensional Gaussian with a diagonal plus low-rank covariance structure, where the true rank is 3.
Then, we can compare the convergence speed of `LowRankGaussian` versus `FullRankGaussian`:

![](lowrank_family_wasserstein.svg)

As the figure shows, `LowRankGaussian` converges faster than `FullRankGaussian`.
`FullRankGaussian` is more expressive and can also converge to the true posterior, but `LowRankGaussian`, whose parameterization matches the low-rank structure of the target, gets there faster.

!!! info
    `MvLocationScaleLowRank` tends to work better with the `Optimisers.Adam` optimizer due to non-smoothness.
    Other optimizers may experience divergence.

The `logpdf` of `LocationScaleLowRank` is unfortunately not computationally efficient by default: it has the same time complexity as `LocationScale` with a full-rank scale.

### API

```@docs
MvLocationScaleLowRank
```

The `logpdf` of `MvLocationScaleLowRank` has an optional keyword argument `non_differentiable::Bool` (default: `false`).
If set to `true`, a more efficient ``O\left(r d^2\right)`` implementation is used to evaluate the density.
This path, however, is not differentiable under most AD frameworks due to its use of the in-place Cholesky `lowrankupdate!`.
The default, `false`, uses an ``O\left(d^3\right)`` implementation that is differentiable and therefore compatible with the `StickingTheLandingEntropy` estimator.
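The two code paths are mathematically equivalent. The following NumPy sketch (an editorial illustration; the Julia implementation instead builds the Cholesky factor incrementally with `lowrankupdate!`) checks a Woodbury-plus-determinant-lemma evaluation of a Gaussian density with covariance ``D^2 + U U^{\top}`` against the dense one:

```python
import numpy as np

rng = np.random.default_rng(2)
d, r = 30, 3
mu = rng.normal(size=d)
D = np.abs(rng.normal(size=d)) + 0.1
U = rng.normal(size=(d, r))
z = rng.normal(size=d)
diff = z - mu

# Dense O(d^3) evaluation
Sigma = np.diag(D**2) + U @ U.T
logdet = np.linalg.slogdet(Sigma)[1]
quad = diff @ np.linalg.solve(Sigma, diff)
logpdf_dense = -0.5 * (d * np.log(2 * np.pi) + logdet + quad)

# O(r d^2)-style evaluation: Woodbury for the solve,
# matrix determinant lemma for the log-determinant
Dinv2 = 1.0 / D**2
inner = np.eye(r) + (U.T * Dinv2) @ U          # I_r + U^T D^{-2} U
logdet_fast = 2 * np.sum(np.log(D)) + np.linalg.slogdet(inner)[1]
w = Dinv2 * diff                               # D^{-2} (z - mu)
quad_fast = diff @ w - (w @ U) @ np.linalg.solve(inner, U.T @ w)
logpdf_fast = -0.5 * (d * np.log(2 * np.pi) + logdet_fast + quad_fast)

assert np.isclose(logpdf_dense, logpdf_fast)
```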

The following is a specialized constructor for convenience:

```@docs
48 changes: 47 additions & 1 deletion src/AdvancedVI.jl
@@ -184,7 +184,53 @@ export MvLocationScaleLowRank, LowRankGaussian

include("families/location_scale_low_rank.jl")

# Optimization Routine
# Optimization Rules

include("optimization/rules.jl")

export DoWG, DoG, COCOB

# Output averaging strategy

abstract type AbstractAverager end

"""
init(avg, params)

Initialize the state of the averaging strategy `avg` with the initial parameters `params`.

# Arguments
- `avg::AbstractAverager`: Averaging strategy.
- `params`: Initial variational parameters.
"""
init(::AbstractAverager, ::Any) = nothing

"""
apply(avg, avg_st, params)

Apply averaging strategy `avg` on `params` given the state `avg_st`.

# Arguments
- `avg::AbstractAverager`: Averaging strategy.
- `avg_st`: Previous state of the averaging strategy.
- `params`: Variational parameters to be averaged.
"""
function apply(::AbstractAverager, ::Any, ::Any) end

"""
value(avg, avg_st)

Compute the output of the averaging strategy `avg` from the state `avg_st`.

# Arguments
- `avg::AbstractAverager`: Averaging strategy.
- `avg_st`: Current state of the averaging strategy.
"""
function value(::AbstractAverager, ::Any) end
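To make the three-function contract concrete, here is a minimal Python sketch (illustrative only; the actual strategies live in `optimization/averaging.jl`), with a plain unweighted running mean standing in for a concrete averager:

```python
# A hypothetical stand-in averaging strategy implementing the same
# init/apply/value triple as the AbstractAverager interface sketch.

def init(avg, params):
    """Initialize the averaging state with the initial parameters."""
    return (list(params), 1)  # (current average, number of iterates seen)

def apply(avg, avg_st, params):
    """Fold the latest parameters into the average."""
    mean, t = avg_st
    new_mean = [m + (p - m) / (t + 1) for m, p in zip(mean, params)]
    return (new_mean, t + 1)

def value(avg, avg_st):
    """Read the averaged parameters out of the state."""
    mean, _ = avg_st
    return mean

avg = "running-mean"  # strategy tag; unused by this simple sketch
st = init(avg, [0.0, 2.0])
st = apply(avg, st, [2.0, 4.0])
st = apply(avg, st, [4.0, 6.0])
assert value(avg, st) == [2.0, 4.0]  # mean of the three iterates
```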

include("optimization/averaging.jl")

export NoAveraging, PolynomialAveraging

function optimize end

13 changes: 9 additions & 4 deletions src/families/location_scale.jl
@@ -45,7 +45,7 @@ Functors.@functor MvLocationScale (location, scale)
# is very inefficient.
# begin
struct RestructureMeanField{S<:Diagonal,D,L,E}
q::MvLocationScale{S,D,L,E}
model::MvLocationScale{S,D,L,E}
end

function (re::RestructureMeanField)(flat::AbstractVector)
@@ -113,16 +113,21 @@ function Distributions._rand!(
return x .+= location
end

Distributions.mean(q::MvLocationScale) = q.location
function Distributions.mean(q::MvLocationScale)
@unpack location, scale = q
return location + scale * Fill(mean(q.dist), length(location))
end

function Distributions.var(q::MvLocationScale)
C = q.scale
return Diagonal(C * C')
σ2 = var(q.dist)
return σ2 * diag(C * C')
end

function Distributions.cov(q::MvLocationScale)
C = q.scale
return Hermitian(C * C')
σ2 = var(q.dist)
return σ2 * Hermitian(C * C')
Comment on lines +129 to +130

**Reviewer (Member):** I don't know the theory here well, but is there a reason why this involves `var(q.dist)` rather than `cov(q.dist)`? I could have imagined it being something like `C * cov(q.dist) * C'`, though that's just a not-very-educated guess.

**Red-Portal (Member, Author):** Good point. I was thinking that `q.dist` was constrained to be a univariate distribution, which would make all of this valid, but it seems I have to use `ContinuousUnivariateDistribution` for that. Let me fix this later.

**Reviewer (Member):** Oh I see, yeah, this makes sense for univariate. Is there a reason you want to restrict `q.dist` to being univariate? Just less of a headache to implement?

**Red-Portal (Member, Author):** Yeah, it seemed to be the easiest way to force people to provide a standardized isotropic distribution. We're not quite forcing it to be standardized, but at least this guarantees it is isotropic.
end
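These moment identities hold whenever the base distribution is univariate, so that the base samples are i.i.d. across coordinates (the point raised in the review thread above). A Monte Carlo NumPy sketch (illustration only) with a non-standardized Uniform(-1, 2) base:

```python
import numpy as np

rng = np.random.default_rng(3)
d, n = 4, 200_000
mu = rng.normal(size=d)
C = np.tril(rng.normal(size=(d, d)))
np.fill_diagonal(C, np.abs(np.diag(C)) + 0.5)

# Univariate base distribution: Uniform(-1, 2); mean 1/2, variance 3/4
eps_mean, eps_var = 0.5, 0.75
eps = rng.uniform(-1.0, 2.0, size=(d, n))

z = mu[:, None] + C @ eps                      # sampling path z = mu + C*eps

# Closed-form moments matching the implementation above
mean_q = mu + C @ np.full(d, eps_mean)         # location + scale * E[eps] * 1
cov_q = eps_var * (C @ C.T)                    # Var[eps] * C C'

assert np.allclose(z.mean(axis=1), mean_q, atol=0.05)
assert np.allclose(np.cov(z), cov_q, atol=0.15)
```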

"""
59 changes: 37 additions & 22 deletions src/families/location_scale_low_rank.jl
@@ -20,9 +20,9 @@ represented as follows:
```julia
d = length(location)
r = size(scale_factors, 2)
u_d = rand(dist, d)
u_f = rand(dist, r)
z = scale_diag.*u_d + scale_factors*u_f + location
u_diag = rand(dist, d)
u_factors = rand(dist, r)
z = scale_diag.*u_diag + scale_factors*u_factors + location
```

`scale_eps` sets a constraint on the smallest value of `scale_diag` to be enforced during optimization.
@@ -60,21 +60,29 @@ function StatsBase.entropy(q::MvLocationScaleLowRank)
return n_dims * convert(eltype(location), entropy(dist)) + logdetΣ / 2
end

function Distributions.logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Real})
function Distributions.logpdf(
q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differentiable::Bool=false
)
@unpack location, scale_diag, scale_factors, dist = q
#
## More efficient O(kd^2) but non-differentiable version:
#
# Σchol = Cholesky(LowerTriangular(diagm(sqrt.(scale_diag))))
# n_factors = size(scale_factors, 2)
# for k in 1:n_factors
# factor = scale_factors[:,k]
# lowrankupdate!(Σchol, factor)
# end

Σ = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors'
Σchol = cholesky(Σ)
return sum(Base.Fix1(logpdf, dist), Σchol.L \ (z - location)) - logdet(Σchol.L)
μ_base = mean(dist)
n_dims = length(location)

scale2chol = if non_differentiable
# Fast O(kd^2) path (not supported by most current AD frameworks):
scale2chol = Cholesky(LowerTriangular(diagm(sqrt.(scale_diag))))
n_factors = size(scale_factors, 2)
for k in 1:n_factors
factor = scale_factors[:, k] # copy necessary due to in-place mutation
lowrankupdate!(scale2chol, factor)
end
scale2chol
else
# Slow but differentiable O(d^3) path
scale2 = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors'
cholesky(scale2)
end
z_std = z - mean(q) + scale2chol.L * Fill(μ_base, n_dims)
return sum(Base.Fix1(logpdf, dist), scale2chol.L \ z_std) - logdet(scale2chol.L)
end

function Distributions.rand(q::MvLocationScaleLowRank)
@@ -111,18 +111,25 @@ function Distributions._rand!(
return x .+= location
end

Distributions.mean(q::MvLocationScaleLowRank) = q.location
function Distributions.mean(q::MvLocationScaleLowRank)
@unpack location, scale_diag, scale_factors = q
μ = mean(q.dist)
return location +
scale_diag .* Fill(μ, length(scale_diag)) +
scale_factors * Fill(μ, size(scale_factors, 2))
end

function Distributions.var(q::MvLocationScaleLowRank)
@unpack scale_diag, scale_factors = q
return Diagonal(
scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1]
)
σ2 = var(q.dist)
return σ2 *
(scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1])
end

function Distributions.cov(q::MvLocationScaleLowRank)
@unpack scale_diag, scale_factors = q
return Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors'
σ2 = var(q.dist)
return σ2 * (Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors')
end
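The low-rank moments can be checked directly against the sampling path `z = location + scale_diag .* u_diag + scale_factors * u_factors`. A Monte Carlo NumPy sketch (illustration only) with a standard normal base, for which `mean(q)` reduces to `location`:

```python
import numpy as np

rng = np.random.default_rng(4)
d, r, n = 5, 2, 200_000
mu = rng.normal(size=d)
D = np.abs(rng.normal(size=d)) + 0.5   # scale_diag
U = rng.normal(size=(d, r))            # scale_factors

# Sampling path: z = mu + D .* u_diag + U @ u_factors, u ~ N(0, 1)
u_diag = rng.normal(size=(d, n))
u_factors = rng.normal(size=(r, n))
z = mu[:, None] + D[:, None] * u_diag + U @ u_factors

# Closed form: Var[eps] * (D^2 + U U'), with Var[eps] = 1 here
cov_q = np.diag(D**2) + U @ U.T

assert np.allclose(z.mean(axis=1), mu, atol=0.05)
assert np.allclose(np.cov(z), cov_q, atol=0.2)
```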

function update_variational_params!(
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -27,7 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ADTypes = "0.2.1, 1"
Bijectors = "0.13"
DiffResults = "1.0"
Distributions = "0.25.100"
Distributions = "0.25.111"
DistributionsAD = "0.6.45"
Enzyme = "0.12.32"
FillArrays = "1.6.1"