Skip to content

Commit

Permalink
fix base distribution standardization bug in LocationScaleLowRank
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Sep 5, 2024
1 parent ba293e5 commit 426d943
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 44 deletions.
53 changes: 34 additions & 19 deletions src/families/location_scale_low_rank.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,29 @@ function StatsBase.entropy(q::MvLocationScaleLowRank)
return n_dims * convert(eltype(location), entropy(dist)) + logdetΣ / 2
end

"""
    logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differntiable=false)

Evaluate the log-density of `q` at the point `z`.

Let `S = Diagonal(scale_diag .^ 2) + scale_factors * scale_factors'` and let `L` be
its lower Cholesky factor. The point is standardized as
`inv(L) * (z - mean(q)) + mean(dist)` (shifted by the base mean so that a base
distribution with non-zero mean is handled), the univariate base log-density
`logpdf(dist, x)` is summed over the coordinates, and `logdet(L)` is subtracted.

# Keywords
- `non_differntiable::Bool=false`: when `true`, `L` is built from the diagonal part
  by successive in-place rank-one updates (`lowrankupdate!`) instead of a dense
  `cholesky` call. The keyword spelling is kept unchanged for compatibility with
  existing callers.
"""
function Distributions.logpdf(
    q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differntiable::Bool=false
)
    @unpack location, scale_diag, scale_factors, dist = q
    μ_base = mean(dist)
    n_dims = length(location)

    scale2chol = if non_differntiable
        # Fast O(kd^2) path (not supported by most current AD frameworks).
        # The diagonal part of the covariance is `scale_diag .^ 2`, so its Cholesky
        # factor is `abs.(scale_diag)` (not `sqrt.(scale_diag)`); this keeps the
        # result identical to the dense path below.
        chol = Cholesky(LowerTriangular(diagm(abs.(scale_diag))))
        for k in axes(scale_factors, 2)
            factor = scale_factors[:, k] # copy necessary due to in-place mutation
            lowrankupdate!(chol, factor)
        end
        chol
    else
        # Slow but differentiable O(d^3) path
        scale2 = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors'
        cholesky(scale2)
    end

    # Extract the lower factor once; it is reused three times below.
    scale2chol_L = scale2chol.L
    z_std = z - mean(q) + scale2chol_L * Fill(μ_base, n_dims)
    return sum(Base.Fix1(logpdf, dist), scale2chol_L \ z_std) - logdet(scale2chol_L)
end

function Distributions.rand(q::MvLocationScaleLowRank)
Expand Down Expand Up @@ -111,18 +119,25 @@ function Distributions._rand!(
return x .+= location
end

"""
    mean(q::MvLocationScaleLowRank)

Return the mean vector of `q`.

With `μ = mean(q.dist)` the mean of the base distribution, the result is
`location + scale_diag .* μ + scale_factors * fill(μ, n_factors)`, i.e. the location
shifted by the base mean pushed through both the diagonal and the low-rank scale.
For a zero-mean base distribution this reduces to `location`.
"""
function Distributions.mean(q::MvLocationScaleLowRank)
    @unpack location, scale_diag, scale_factors = q
    μ = mean(q.dist)
    n_factors = size(scale_factors, 2)
    # `Fill` represents the constant vectors lazily instead of materializing them.
    diag_shift = scale_diag .* Fill(μ, length(scale_diag))
    factor_shift = scale_factors * Fill(μ, n_factors)
    return location + diag_shift + factor_shift
end

"""
    var(q::MvLocationScaleLowRank)

Return the vector of marginal variances of `q`.

Entry `i` is `var(q.dist) * (scale_diag[i]^2 + sum(abs2, scale_factors[i, :]))`,
which is the diagonal of `cov(q)`. The result is a plain vector rather than a
`Diagonal` matrix.
"""
function Distributions.var(q::MvLocationScaleLowRank)
    @unpack scale_diag, scale_factors = q
    σ2 = var(q.dist)
    # Row-wise squared norms of the factor matrix, flattened from d×1 to length d.
    factor_sq = vec(sum(abs2, scale_factors; dims=2))
    return σ2 * (abs2.(scale_diag) + factor_sq)
end

"""
    cov(q::MvLocationScaleLowRank)

Return the covariance matrix of `q`:
`var(q.dist) * (Diagonal(scale_diag .^ 2) + scale_factors * scale_factors')`.

The factor `var(q.dist)` accounts for a base distribution whose variance is not one.
"""
function Distributions.cov(q::MvLocationScaleLowRank)
    @unpack scale_diag, scale_factors = q
    σ2 = var(q.dist)
    scale2 = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors'
    return σ2 * scale2
end

function update_variational_params!(
Expand Down
74 changes: 49 additions & 25 deletions test/families/location_scale_low_rank.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@

@testset "interface LocationScaleLowRank" begin
@testset "$(basedist) rank=$(rank) $(realtype)" for basedist in [:gaussian],
rank in [1, 2],
@testset "$(basedist) rank=$(rank) $(realtype)" for basedist in
[:gaussian, :gaussian_nonstd],
n_rank in [1, 2],
realtype in [Float32, Float64]

n_dims = 10
n_montecarlo = 1000_000

μ = randn(realtype, n_dims)
D = ones(realtype, n_dims)
U = randn(realtype, n_dims, rank)
Σ = Diagonal(D .^ 2) + U * U'
location = randn(realtype, n_dims)
scale_diag = ones(realtype, n_dims)
scale_factors = randn(realtype, n_dims, n_rank)

q = if basedist == :gaussian
LowRankGaussian(μ, D, U)
LowRankGaussian(location, scale_diag, scale_factors)
elseif basedist == :gaussian_nonstd
MvLocationScaleLowRank(
location, scale_diag, scale_factors, Normal(realtype(3), realtype(3))
)
end

q_true = if basedist == :gaussian
μ = location
Σ = Diagonal(scale_diag .^ 2) + scale_factors * scale_factors'
MvNormal(location, Σ)
elseif basedist == :gaussian_nonstd
μ = location + scale_diag .* fill(3, n_dims) + scale_factors * fill(3, n_rank)
Σ = 3^2 * (Diagonal(scale_diag .^ 2) + scale_factors * scale_factors')
MvNormal(μ, Σ)
end

Expand All @@ -27,6 +38,11 @@
z = rand(q)
@test logpdf(q, z) ≈ logpdf(q_true, z) rtol = realtype(1e-2)
@test eltype(logpdf(q, z)) == realtype

@test logpdf(q, z; non_differntiable=true) ≈ logpdf(q_true, z) rtol = realtype(
1e-2
)
@test eltype(logpdf(q, z; non_differntiable=true)) == realtype
end

@testset "entropy" begin
Expand All @@ -41,27 +57,29 @@
@testset "statistics" begin
@testset "mean" begin
@test eltype(mean(q)) == realtype
@test mean(q) == μ
@test mean(q) == mean(q_true)
end
@testset "var" begin
@test eltype(var(q)) == realtype
@test var(q) Diagonal)
@test var(q) ≈ var(q_true)
end
@testset "cov" begin
@test eltype(cov(q)) == realtype
@test cov(q) ≈ Σ
@test cov(q) ≈ cov(q_true)
end
end

@testset "sampling" begin
@testset "rand" begin
z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo)
@test eltype(z_samples) == realtype
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype(
1e-2
)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2)
@test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2)

z_sample_ref = rand(StableRNG(1), q)
@test z_sample_ref == rand(StableRNG(1), q)
Expand All @@ -70,11 +88,13 @@
@testset "rand batch" begin
z_samples = rand(q, n_montecarlo)
@test eltype(z_samples) == realtype
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2)

samples_ref = rand(StableRNG(1), q, n_montecarlo)
@test samples_ref == rand(StableRNG(1), q, n_montecarlo)
Expand All @@ -89,11 +109,13 @@
z_samples = mapreduce(first, hcat, res)
z_samples_ret = mapreduce(last, hcat, res)
@test z_samples == z_samples_ret
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype(
1e-2
)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2)
@test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2)

z_sample_ref = Array{realtype}(undef, n_dims)
rand!(StableRNG(1), q, z_sample_ref)
Expand All @@ -107,11 +129,13 @@
z_samples = Array{realtype}(undef, n_dims, n_montecarlo)
z_samples_ret = rand!(q, z_samples)
@test z_samples == z_samples_ret
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype(
1e-2
)
@test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2)
@test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2)

z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo)
rand!(StableRNG(1), q, z_samples_ref)
Expand All @@ -127,12 +151,12 @@
@testset "$(realtype) $(bijector)" for realtype in [Float32, Float64],
bijector in [nothing, :identity]

rank = 2
n_rank = 2
d = 5
μ = zeros(realtype, d)
ϵ = sqrt(realtype(0.5))
D = ones(realtype, d)
U = randn(realtype, d, rank)
U = randn(realtype, d, n_rank)
q = MvLocationScaleLowRank(
μ, D, U, Normal{realtype}(zero(realtype), one(realtype)); scale_eps=ϵ
)
Expand All @@ -148,7 +172,7 @@
opt_st = Optimisers.setup(Descent(one(realtype)), λ)
_, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad)
q′ = re(λ′)
@test all(diag(var(q′)) .≥ ϵ^2)
@test all(var(q′) .≥ ϵ^2)
end
end
end

0 comments on commit 426d943

Please sign in to comment.