Skip to content

Commit

Permalink
Add PositiveDefinite implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace committed Feb 10, 2024
1 parent e0e41bf commit 7f07352
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
37 changes: 35 additions & 2 deletions src/parameters_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ end
"""
positive_semidefinite(X::AbstractMatrix{<:Real})
Produce a parameter whose `value` is constrained to be a positive-semidefinite matrix. The argument `X` needs to
be a positive-definite matrix (see https://en.wikipedia.org/wiki/Definite_matrix).
Produce a parameter whose `value` is constrained to be a positive-semidefinite matrix. The
argument `X` needs to be a positive-definite matrix
(see https://en.wikipedia.org/wiki/Definite_matrix).
The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector.
Expand All @@ -57,6 +58,23 @@ function positive_semidefinite(X::AbstractMatrix{<:Real})
return PositiveSemiDefinite(tril_to_vec(cholesky(X).L))
end

"""
positive_definite(X::AbstractMatrix{<:Real}, ε = eps(T))
Produce a parameter whose `value` is constrained to be a strictly positive-semidefinite
matrix. The argument `X` minus `ε` times the identity needs to be a positive-definite matrix
(see https://en.wikipedia.org/wiki/Definite_matrix). The optional second argument `ε` must
be a positive real number.
The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector.
"""
function positive_definite(X::AbstractMatrix{T}, ε = eps(T)) where T <: Real
ε > 0 || throw(ArgumentError("ε is not positive. Use `positive_semidefinite` instead."))
_X = X - ε * I
isposdef(_X) || throw(ArgumentError("X-ε*I is not positive-definite for ε="))
return PositiveDefinite(tril_to_vec(cholesky(_X).L), ε)
end

struct PositiveSemiDefinite{TL<:AbstractVector{<:Real}} <: AbstractParameter
L::TL
end
Expand All @@ -73,6 +91,21 @@ function flatten(::Type{T}, X::PositiveSemiDefinite) where {T<:Real}
return v, unflatten_PositiveSemiDefinite
end

struct PositiveDefinite{TL<:AbstractVector{<:Real}, Tε<:Real} <: AbstractParameter
L::TL
ε::Tε
end

Base.:(==)(X::PositiveDefinite, Y::PositiveDefinite) = X.L == Y.L

Check warning on line 99 in src/parameters_matrix.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters_matrix.jl#L99

Added line #L99 was not covered by tests

value(X::PositiveDefinite) = A_At(vec_to_tril(X.L)) + X.ε * I

function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real}
v, unflatten_v = flatten(T, X.L)
unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new), X.ε)
return v, unflatten_PositiveDefinite
end

# Convert a vector to lower-triangular matrix
function vec_to_tril(v::AbstractVector{T}) where {T}
n_vec = length(v)
Expand Down
20 changes: 20 additions & 0 deletions test/parameters_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,24 @@ using ParameterHandling: vec_to_tril, tril_to_vec
@test vec_to_tril(Δl) == tril(ΔL)
ChainRulesTestUtils.test_rrule(vec_to_tril, x)
end

@testset "positive_definite" begin
X_mat = ParameterHandling.A_At(rand(3, 3)) # Create a positive definite object
X = positive_definite(X_mat)
@test isposdef(value(X))
X.L .= 0 # zero the unconstrained value
@test isposdef(value(X))
@test_throws ArgumentError positive_definite(zeros(3, 3))
@test_throws ArgumentError positive_definite(X_mat, 0.)
test_parameter_interface(X)

x, re = flatten(X)
Δl = first(
Zygote.gradient(x) do x
X = re(x)
return logdet(value(X))
end,
)
ChainRulesTestUtils.test_rrule(vec_to_tril, x)
end
end

0 comments on commit 7f07352

Please sign in to comment.