Skip to content

Commit

Permalink
fix base distribution standardization bug in LocationScale
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Sep 5, 2024
1 parent 5c04d50 commit ba293e5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 33 deletions.
11 changes: 8 additions & 3 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,21 @@ function Distributions._rand!(
return x .+= location
end

Distributions.mean(q::MvLocationScale) = q.location
function Distributions.mean(q::MvLocationScale)
@unpack location, scale = q
return location + scale * Fill(mean(q.dist), length(location))
end

function Distributions.var(q::MvLocationScale)
C = q.scale
return Diagonal(C * C')
σ2 = var(q.dist)
return σ2 * diag(C * C')
end

function Distributions.cov(q::MvLocationScale)
C = q.scale
return Hermitian(C * C')
σ2 = var(q.dist)
return σ2 * Hermitian(C * C')
end

"""
Expand Down
67 changes: 37 additions & 30 deletions test/families/location_scale.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@

@testset "interface LocationScale" begin
@testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian, :studentt],
@testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in
[:gaussian, :gaussian_nonstd],
covtype in [:meanfield, :fullrank],
realtype in [Float32, Float64]

n_dims = 10
n_montecarlo = 1000_000

μ = randn(realtype, n_dims)
L = if covtype == :fullrank
location = randn(realtype, n_dims)
scale = if covtype == :fullrank
LowerTriangular(tril(I + ones(realtype, n_dims, n_dims) / 2))
else
Diagonal(ones(realtype, n_dims))
end
Σ = L * L'

q = if covtype == :fullrank && basedist == :gaussian
FullRankGaussian(μ, L)
FullRankGaussian(location, scale)
elseif covtype == :meanfield && basedist == :gaussian
MeanFieldGaussian(μ, L)
elseif covtype == :fullrank && basedist == :studentt
MvLocationScale(μ, L, TDist(realtype(10.0)))
elseif covtype == :meanfield && basedist == :studentt
MvLocationScale(μ, L, TDist(realtype(10.0)))
MeanFieldGaussian(location, scale)
elseif covtype == :fullrank && basedist == :gaussian_nonstd
MvLocationScale(location, scale, Normal(realtype(3), realtype(3)))
elseif covtype == :meanfield && basedist == :gaussian_nonstd
MvLocationScale(location, scale, Normal(realtype(3), realtype(3)))
end

q_true = if basedist == :gaussian
MvNormal(μ, Σ)
elseif basedist == :studentt
MvTDist(realtype(10.0), μ, Matrix(Σ))
MvNormal(location, scale * scale')
elseif basedist == :gaussian_nonstd
MvNormal(location + scale * fill(3, n_dims), 9 * scale * scale')
end

println(q)
@testset "eltype" begin
@test eltype(q) == realtype
end
Expand All @@ -54,27 +53,29 @@
@testset "statistics" begin
@testset "mean" begin
@test eltype(mean(q)) == realtype
@test mean(q) == μ
@test mean(q) mean(q_true)
end
@testset "var" begin
@test eltype(var(q)) == realtype
@test var(q) Diagonal)
@test var(q) var(q_true)
end
@testset "cov" begin
@test eltype(cov(q)) == realtype
@test cov(q) Σ
@test cov(q) cov(q_true)
end
end

@testset "sampling" begin
@testset "rand" begin
z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo)
@test eltype(z_samples) == realtype
@test dropdims(mean(z_samples; dims=2); dims=2) μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) mean(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) Σ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) cov(q_true) rtol = realtype(1e-2)

z_sample_ref = rand(StableRNG(1), q)
@test z_sample_ref == rand(StableRNG(1), q)
Expand All @@ -83,11 +84,13 @@
@testset "rand batch" begin
z_samples = rand(q, n_montecarlo)
@test eltype(z_samples) == realtype
@test dropdims(mean(z_samples; dims=2); dims=2) μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) mean(q_true) rtol = realtype(
1e-2
)
@test dropdims(var(z_samples; dims=2); dims=2) var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) Σ rtol = realtype(1e-2)
@test cov(z_samples; dims=2) cov(q_true) rtol = realtype(1e-2)

samples_ref = rand(StableRNG(1), q, n_montecarlo)
@test samples_ref == rand(StableRNG(1), q, n_montecarlo)
Expand All @@ -102,11 +105,13 @@
z_samples = mapreduce(first, hcat, res)
z_samples_ret = mapreduce(last, hcat, res)
@test z_samples == z_samples_ret
@test dropdims(mean(z_samples; dims=2); dims=2) μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) mean(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) Σ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) cov(q_true) rtol = realtype(1e-2)

z_sample_ref = Array{realtype}(undef, n_dims)
rand!(StableRNG(1), q, z_sample_ref)
Expand All @@ -120,11 +125,13 @@
z_samples = Array{realtype}(undef, n_dims, n_montecarlo)
z_samples_ret = rand!(q, z_samples)
@test z_samples == z_samples_ret
@test dropdims(mean(z_samples; dims=2); dims=2) μ rtol = realtype(1e-2)
@test dropdims(var(z_samples; dims=2); dims=2) diag(Σ) rtol = realtype(
@test dropdims(mean(z_samples; dims=2); dims=2) mean(q_true) rtol = realtype(
1e-2
)
@test dropdims(var(z_samples; dims=2); dims=2) var(q_true) rtol = realtype(
1e-2
)
@test cov(z_samples; dims=2) Σ rtol = realtype(1e-2)
@test cov(z_samples; dims=2) cov(q_true) rtol = realtype(1e-2)

z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo)
rand!(StableRNG(1), q, z_samples_ref)
Expand Down Expand Up @@ -164,7 +171,7 @@
opt_st = Optimisers.setup(Descent(one(realtype)), λ)
_, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad)
q′ = re(λ′)
@test all(diag(var(q′)) .≥ ϵ^2)
@test all(var(q′) .≥ ϵ^2)
end
end

Expand Down

0 comments on commit ba293e5

Please sign in to comment.