From 7f073521e1c3a077384e7696007467d03e03fd18 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sat, 10 Feb 2024 12:22:14 +0100 Subject: [PATCH] Add `PositiveDefinite` implementation and tests --- src/parameters_matrix.jl | 37 +++++++++++++++++++++++++++++++++++-- test/parameters_matrix.jl | 20 ++++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/parameters_matrix.jl b/src/parameters_matrix.jl index 0949d14..f72f95c 100644 --- a/src/parameters_matrix.jl +++ b/src/parameters_matrix.jl @@ -41,8 +41,9 @@ end """ positive_semidefinite(X::AbstractMatrix{<:Real}) -Produce a parameter whose `value` is constrained to be a positive-semidefinite matrix. The argument `X` needs to -be a positive-definite matrix (see https://en.wikipedia.org/wiki/Definite_matrix). +Produce a parameter whose `value` is constrained to be a positive-semidefinite matrix. The +argument `X` needs to be a positive-definite matrix +(see https://en.wikipedia.org/wiki/Definite_matrix). The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector. @@ -57,6 +58,23 @@ function positive_semidefinite(X::AbstractMatrix{<:Real}) return PositiveSemiDefinite(tril_to_vec(cholesky(X).L)) end +""" + positive_definite(X::AbstractMatrix{<:Real}, ε = eps(T)) + +Produce a parameter whose `value` is constrained to be a strictly positive-semidefinite +matrix. The argument `X` minus `ε` times the identity needs to be a positive-definite matrix +(see https://en.wikipedia.org/wiki/Definite_matrix). The optional second argument `ε` must +be a positive real number. + +The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector. +""" +function positive_definite(X::AbstractMatrix{T}, ε = eps(T)) where T <: Real + ε > 0 || throw(ArgumentError("ε is not positive. Use `positive_semidefinite` instead.")) + _X = X - ε * I + isposdef(_X) || throw(ArgumentError("X-ε*I is not positive-definite for ε=$ε")) + return PositiveDefinite(tril_to_vec(cholesky(_X).L), ε) +end + struct PositiveSemiDefinite{TL<:AbstractVector{<:Real}} <: AbstractParameter L::TL end @@ -73,6 +91,21 @@ function flatten(::Type{T}, X::PositiveSemiDefinite) where {T<:Real} return v, unflatten_PositiveSemiDefinite end +struct PositiveDefinite{TL<:AbstractVector{<:Real}, Tε<:Real} <: AbstractParameter + L::TL + ε::Tε +end + +Base.:(==)(X::PositiveDefinite, Y::PositiveDefinite) = X.L == Y.L + +value(X::PositiveDefinite) = A_At(vec_to_tril(X.L)) + X.ε * I + +function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real} + v, unflatten_v = flatten(T, X.L) + unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new), X.ε) + return v, unflatten_PositiveDefinite +end + # Convert a vector to lower-triangular matrix function vec_to_tril(v::AbstractVector{T}) where {T} n_vec = length(v) diff --git a/test/parameters_matrix.jl b/test/parameters_matrix.jl index 0b90376..019631e 100644 --- a/test/parameters_matrix.jl +++ b/test/parameters_matrix.jl @@ -54,4 +54,24 @@ using ParameterHandling: vec_to_tril, tril_to_vec @test vec_to_tril(Δl) == tril(ΔL) ChainRulesTestUtils.test_rrule(vec_to_tril, x) end + + @testset "positive_definite" begin + X_mat = ParameterHandling.A_At(rand(3, 3)) # Create a positive definite object + X = positive_definite(X_mat) + @test isposdef(value(X)) + X.L .= 0 # zero the unconstrained value + @test isposdef(value(X)) + @test_throws ArgumentError positive_definite(zeros(3, 3)) + @test_throws ArgumentError positive_definite(X_mat, 0.) + test_parameter_interface(X) + + x, re = flatten(X) + Δl = first( + Zygote.gradient(x) do x + X = re(x) + return logdet(value(X)) + end, + ) + ChainRulesTestUtils.test_rrule(vec_to_tril, x) + end end