diff --git a/src/ParameterHandling.jl b/src/ParameterHandling.jl index 295f0ee..b3c1ddc 100644 --- a/src/ParameterHandling.jl +++ b/src/ParameterHandling.jl @@ -11,7 +11,10 @@ export flatten, value_flatten, positive, bounded, fixed, deferred, orthogonal, positive_definite include("flatten.jl") -include("parameters.jl") +include("parameters_base.jl") +include("parameters_meta.jl") +include("parameters_scalar.jl") +include("parameters_matrix.jl") include("test_utils.jl") diff --git a/src/parameters.jl b/src/parameters.jl deleted file mode 100644 index 944ebe2..0000000 --- a/src/parameters.jl +++ /dev/null @@ -1,247 +0,0 @@ -abstract type AbstractParameter end - -""" - value(x) - -Return the "value" of an object. -For `AbstractParameter`s this typically applies some transformation to some data -contained in the parameter, and returns a plain data type. -It might, for example, return a transformation of some internal data, the result of which -is guaranteed to satisfy some constraint. -""" -value(x) - -# Various basic `value` definitions. -value(x::Number) = x -value(x::AbstractArray{<:Number}) = x -value(x::AbstractArray) = map(value, x) -value(x::Tuple) = map(value, x) -value(x::NamedTuple) = map(value, x) -value(x::Dict) = Dict(k => value(v) for (k, v) in x) - -""" - positive(val::Real, transform=exp, ε=sqrt(eps(typeof(val)))) - -Return a `Positive`. -The `value` of a `Positive` is a `Real` number that is constrained to be positive. -This is represented in terms of a `transform` that maps an `unconstrained_value` to the -positive reals. -Satisfies `val ≈ transform(unconstrained_value)`. -""" -function positive(val::Real, transform=exp, ε=sqrt(eps(typeof(val)))) - val > 0 || throw(ArgumentError("Value ($val) is not positive.")) - val > ε || throw(ArgumentError("Value ($val) is too small, relative to ε ($ε).")) - unconstrained_value = inverse(transform)(val - ε) - return Positive(unconstrained_value, transform, convert(typeof(unconstrained_value), ε)) -end - -struct Positive{T<:Real,V,Tε<:Real} <: AbstractParameter - unconstrained_value::T - transform::V - ε::Tε -end - -value(x::Positive) = x.transform(x.unconstrained_value) + x.ε - -function flatten(::Type{T}, x::Positive) where {T<:Real} - v, unflatten_to_Real = flatten(T, x.unconstrained_value) - - function unflatten_Positive(v_new::Vector{T}) - return Positive(unflatten_to_Real(v_new), x.transform, x.ε) - end - - return v, unflatten_Positive -end - -""" - bounded(val::Real, lower_bound::Real, upper_bound::Real) - -Constructs a `Bounded`. -The `value` of a `Bounded` is a `Real` number that is constrained to be within the interval -(`lower_bound`, `upper_bound`), and is equal to `val`. -This is represented internally in terms of an `unconstrained_value` and a `transform` that -maps any `Real` to this interval. -""" -function bounded(val::Real, lower_bound::Real, upper_bound::Real) - lb = convert(typeof(val), lower_bound) - ub = convert(typeof(val), upper_bound) - - # construct open interval - ε = convert(typeof(val), 1e-12) - lb_plus_ε = lb + ε - ub_minus_ε = ub - ε - - if val > ub_minus_ε || val < lb_plus_ε - throw( - ArgumentError( - "Value, $val, outside of specified bounds ($lower_bound, $upper_bound)." - ), - ) - end - - length_interval = ub_minus_ε - lb_plus_ε - unconstrained_val = logit((val - lb_plus_ε) / length_interval) - transform(x) = lb_plus_ε + length_interval * logistic(x) - - return Bounded(unconstrained_val, lb, ub, transform, ε) -end - -struct Bounded{T<:Real,V,Tε<:Real} <: AbstractParameter - unconstrained_value::T - lower_bound::T - upper_bound::T - transform::V - ε::Tε -end - -value(x::Bounded) = x.transform(x.unconstrained_value) - -function flatten(::Type{T}, x::Bounded) where {T<:Real} - v, unflatten_to_Real = flatten(T, x.unconstrained_value) - - function unflatten_Bounded(v_new::Vector{T}) - return Bounded( - unflatten_to_Real(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε - ) - end - - return v, unflatten_Bounded -end - -""" - fixed(val) - -Represents a parameter whose value is required to stay constant. The `value` of a `Fixed` is -simply `val`. Constantness of the parameter is enforced by returning an empty -vector from `flatten`. -""" -fixed(val) = Fixed(val) - -struct Fixed{T} <: AbstractParameter - value::T -end - -value(x::Fixed) = value(x.value) - -function flatten(::Type{T}, x::Fixed) where {T<:Real} - unflatten_Fixed(v_new::Vector{T}) = x - return T[], unflatten_Fixed -end - -""" - deferred(f, args...) - -The `value` of a `deferred` is `f(value(args)...)`. This makes it possible to make the value -of the `args` e.g. `AbstractParameter`s and, therefore, enforce constraints on them even if -`f` knows nothing about `AbstractParameters`. - -It can be helpful to use `deferred` recursively when constructing complicated objects. -""" -deferred(f, args...) = Deferred(f, args) - -struct Deferred{Tf,Targs} <: AbstractParameter - f::Tf - args::Targs -end - -Base.:(==)(a::Deferred, b::Deferred) = (a.f == b.f) && (a.args == b.args) - -value(x::Deferred) = x.f(value(x.args)...) - -function flatten(::Type{T}, x::Deferred) where {T<:Real} - v, unflatten = flatten(T, x.args) - unflatten_Deferred(v_new::Vector{T}) = Deferred(x.f, unflatten(v_new)) - return v, unflatten_Deferred -end - -""" - nearest_orthogonal_matrix(X::StridedMatrix) - -Project `X` onto the closest orthogonal matrix in Frobenius norm. - -Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446 -""" -@inline function nearest_orthogonal_matrix(X::StridedMatrix{<:Union{Real,Complex}}) - # Inlining necessary for type inference for some reason. - U, _, V = svd(X) - return U * V' -end - -""" - orthogonal(X::StridedMatrix{<:Real}) - -Produce a parameter whose `value` is constrained to be an orthogonal matrix. The argument `X` need not -be orthogonal. - -This functionality projects `X` onto the nearest element subspace of orthogonal matrices (in -Frobenius norm) and is overparametrised as a consequence. - -Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446 -""" -orthogonal(X::StridedMatrix{<:Real}) = Orthogonal(X) - -struct Orthogonal{TX<:StridedMatrix{<:Real}} <: AbstractParameter - X::TX -end - -Base.:(==)(X::Orthogonal, Y::Orthogonal) = X.X == Y.X - -value(X::Orthogonal) = nearest_orthogonal_matrix(X.X) - -function flatten(::Type{T}, X::Orthogonal) where {T<:Real} - v, unflatten_to_Array = flatten(T, X.X) - unflatten_Orthogonal(v_new::Vector{T}) = Orthogonal(unflatten_to_Array(v_new)) - return v, unflatten_Orthogonal -end - -""" - positive_definite(X::StridedMatrix{<:Real}) - -Produce a parameter whose `value` is constrained to be a positive-definite matrix. The argument `X` needs to -be a positive-definite matrix (see https://en.wikipedia.org/wiki/Definite_matrix). - -The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector. -""" -function positive_definite(X::StridedMatrix{<:Real}) - isposdef(X) || throw(ArgumentError("X is not positive-definite")) - return PositiveDefinite(tril_to_vec(cholesky(X).L)) -end - -struct PositiveDefinite{TL<:AbstractVector{<:Real}} <: AbstractParameter - L::TL -end - -Base.:(==)(X::PositiveDefinite, Y::PositiveDefinite) = X.L == Y.L - -A_At(X) = X * X' - -value(X::PositiveDefinite) = A_At(vec_to_tril(X.L)) - -function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real} - v, unflatten_v = flatten(T, X.L) - unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new)) - return v, unflatten_PositiveDefinite -end - -# Convert a vector to lower-triangular matrix -function vec_to_tril(v::AbstractVector{T}) where {T} - n_vec = length(v) - n_tril = Int((sqrt(1 + 8 * n_vec) - 1) / 2) # Infer the size of the matrix from the vector - L = zeros(T, n_tril, n_tril) - L[tril!(trues(size(L)))] = v - return L -end - -function ChainRulesCore.rrule(::typeof(vec_to_tril), v::AbstractVector{T}) where {T} - L = vec_to_tril(v) - pullback_vec_to_tril(Δ) = NoTangent(), tril_to_vec(unthunk(Δ)) - return L, pullback_vec_to_tril -end - -# Convert a lower-triangular matrix to a vector (without the zeros) -# Adapted from https://stackoverflow.com/questions/50651781/extract-lower-triangle-portion-of-a-matrix -function tril_to_vec(X::AbstractMatrix{T}) where {T} - n, m = size(X) - n == m || error("Matrix needs to be square") - return X[tril!(trues(size(X)))] -end diff --git a/src/parameters_base.jl b/src/parameters_base.jl new file mode 100644 index 0000000..61c17d4 --- /dev/null +++ b/src/parameters_base.jl @@ -0,0 +1,20 @@ +abstract type AbstractParameter end + +""" + value(x) + +Return the "value" of an object. +For `AbstractParameter`s this typically applies some transformation to some data +contained in the parameter, and returns a plain data type. +It might, for example, return a transformation of some internal data, the result of which +is guaranteed to satisfy some constraint. +""" +value(x) + +# Various basic `value` definitions. +value(x::Number) = x +value(x::AbstractArray{<:Number}) = x +value(x::AbstractArray) = map(value, x) +value(x::Tuple) = map(value, x) +value(x::NamedTuple) = map(value, x) +value(x::Dict) = Dict(k => value(v) for (k, v) in x) diff --git a/src/parameters_matrix.jl b/src/parameters_matrix.jl new file mode 100644 index 0000000..4332254 --- /dev/null +++ b/src/parameters_matrix.jl @@ -0,0 +1,91 @@ +""" + nearest_orthogonal_matrix(X::StridedMatrix) + +Project `X` onto the closest orthogonal matrix in Frobenius norm. + +Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446 +""" +@inline function nearest_orthogonal_matrix(X::StridedMatrix{<:Union{Real,Complex}}) + # Inlining necessary for type inference for some reason. + U, _, V = svd(X) + return U * V' +end + +""" + orthogonal(X::StridedMatrix{<:Real}) + +Produce a parameter whose `value` is constrained to be an orthogonal matrix. The argument `X` need not +be orthogonal. + +This functionality projects `X` onto the nearest element subspace of orthogonal matrices (in +Frobenius norm) and is overparametrised as a consequence. + +Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446 +""" +orthogonal(X::StridedMatrix{<:Real}) = Orthogonal(X) + +struct Orthogonal{TX<:StridedMatrix{<:Real}} <: AbstractParameter + X::TX +end + +Base.:(==)(X::Orthogonal, Y::Orthogonal) = X.X == Y.X + +value(X::Orthogonal) = nearest_orthogonal_matrix(X.X) + +function flatten(::Type{T}, X::Orthogonal) where {T<:Real} + v, unflatten_to_Array = flatten(T, X.X) + unflatten_Orthogonal(v_new::Vector{T}) = Orthogonal(unflatten_to_Array(v_new)) + return v, unflatten_Orthogonal +end + +""" + positive_definite(X::StridedMatrix{<:Real}) + +Produce a parameter whose `value` is constrained to be a positive-definite matrix. The argument `X` needs to +be a positive-definite matrix (see https://en.wikipedia.org/wiki/Definite_matrix). + +The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector. +""" +function positive_definite(X::StridedMatrix{<:Real}) + isposdef(X) || throw(ArgumentError("X is not positive-definite")) + return PositiveDefinite(tril_to_vec(cholesky(X).L)) +end + +struct PositiveDefinite{TL<:AbstractVector{<:Real}} <: AbstractParameter + L::TL +end + +Base.:(==)(X::PositiveDefinite, Y::PositiveDefinite) = X.L == Y.L + +A_At(X) = X * X' + +value(X::PositiveDefinite) = A_At(vec_to_tril(X.L)) + +function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real} + v, unflatten_v = flatten(T, X.L) + unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new)) + return v, unflatten_PositiveDefinite +end + +# Convert a vector to lower-triangular matrix +function vec_to_tril(v::AbstractVector{T}) where {T} + n_vec = length(v) + n_tril = Int((sqrt(1 + 8 * n_vec) - 1) / 2) # Infer the size of the matrix from the vector + L = zeros(T, n_tril, n_tril) + L[tril!(trues(size(L)))] = v + return L +end + +function ChainRulesCore.rrule(::typeof(vec_to_tril), v::AbstractVector{T}) where {T} + L = vec_to_tril(v) + pullback_vec_to_tril(Δ) = NoTangent(), tril_to_vec(unthunk(Δ)) + return L, pullback_vec_to_tril +end + +# Convert a lower-triangular matrix to a vector (without the zeros) +# Adapted from https://stackoverflow.com/questions/50651781/extract-lower-triangle-portion-of-a-matrix +function tril_to_vec(X::AbstractMatrix{T}) where {T} + n, m = size(X) + n == m || error("Matrix needs to be square") + return X[tril!(trues(size(X)))] +end diff --git a/src/parameters_meta.jl b/src/parameters_meta.jl new file mode 100644 index 0000000..7de7646 --- /dev/null +++ b/src/parameters_meta.jl @@ -0,0 +1,45 @@ +""" + fixed(val) + +Represents a parameter whose value is required to stay constant. The `value` of a `Fixed` is +simply `val`. Constantness of the parameter is enforced by returning an empty +vector from `flatten`. +""" +fixed(val) = Fixed(val) + +struct Fixed{T} <: AbstractParameter + value::T +end + +value(x::Fixed) = value(x.value) + +function flatten(::Type{T}, x::Fixed) where {T<:Real} + unflatten_Fixed(v_new::Vector{T}) = x + return T[], unflatten_Fixed +end + +""" + deferred(f, args...) + +The `value` of a `deferred` is `f(value(args)...)`. This makes it possible to make the value +of the `args` e.g. `AbstractParameter`s and, therefore, enforce constraints on them even if +`f` knows nothing about `AbstractParameters`. + +It can be helpful to use `deferred` recursively when constructing complicated objects. +""" +deferred(f, args...) = Deferred(f, args) + +struct Deferred{Tf,Targs} <: AbstractParameter + f::Tf + args::Targs +end + +Base.:(==)(a::Deferred, b::Deferred) = (a.f == b.f) && (a.args == b.args) + +value(x::Deferred) = x.f(value(x.args)...) + +function flatten(::Type{T}, x::Deferred) where {T<:Real} + v, unflatten = flatten(T, x.args) + unflatten_Deferred(v_new::Vector{T}) = Deferred(x.f, unflatten(v_new)) + return v, unflatten_Deferred +end diff --git a/src/parameters_scalar.jl b/src/parameters_scalar.jl new file mode 100644 index 0000000..00ceba4 --- /dev/null +++ b/src/parameters_scalar.jl @@ -0,0 +1,88 @@ +""" + positive(val::Real, transform=exp, ε=sqrt(eps(typeof(val)))) + +Return a `Positive`. +The `value` of a `Positive` is a `Real` number that is constrained to be positive. +This is represented in terms of a `transform` that maps an `unconstrained_value` to the +positive reals. +Satisfies `val ≈ transform(unconstrained_value)`. +""" +function positive(val::Real, transform=exp, ε=sqrt(eps(typeof(val)))) + val > 0 || throw(ArgumentError("Value ($val) is not positive.")) + val > ε || throw(ArgumentError("Value ($val) is too small, relative to ε ($ε).")) + unconstrained_value = inverse(transform)(val - ε) + return Positive(unconstrained_value, transform, convert(typeof(unconstrained_value), ε)) +end + +struct Positive{T<:Real,V,Tε<:Real} <: AbstractParameter + unconstrained_value::T + transform::V + ε::Tε +end + +value(x::Positive) = x.transform(x.unconstrained_value) + x.ε + +function flatten(::Type{T}, x::Positive) where {T<:Real} + v, unflatten_to_Real = flatten(T, x.unconstrained_value) + + function unflatten_Positive(v_new::Vector{T}) + return Positive(unflatten_to_Real(v_new), x.transform, x.ε) + end + + return v, unflatten_Positive +end + +""" + bounded(val::Real, lower_bound::Real, upper_bound::Real) + +Constructs a `Bounded`. +The `value` of a `Bounded` is a `Real` number that is constrained to be within the interval +(`lower_bound`, `upper_bound`), and is equal to `val`. +This is represented internally in terms of an `unconstrained_value` and a `transform` that +maps any `Real` to this interval. +""" +function bounded(val::Real, lower_bound::Real, upper_bound::Real) + lb = convert(typeof(val), lower_bound) + ub = convert(typeof(val), upper_bound) + + # construct open interval + ε = convert(typeof(val), 1e-12) + lb_plus_ε = lb + ε + ub_minus_ε = ub - ε + + if val > ub_minus_ε || val < lb_plus_ε + throw( + ArgumentError( + "Value, $val, outside of specified bounds ($lower_bound, $upper_bound)." + ), + ) + end + + length_interval = ub_minus_ε - lb_plus_ε + unconstrained_val = logit((val - lb_plus_ε) / length_interval) + transform(x) = lb_plus_ε + length_interval * logistic(x) + + return Bounded(unconstrained_val, lb, ub, transform, ε) +end + +struct Bounded{T<:Real,V,Tε<:Real} <: AbstractParameter + unconstrained_value::T + lower_bound::T + upper_bound::T + transform::V + ε::Tε +end + +value(x::Bounded) = x.transform(x.unconstrained_value) + +function flatten(::Type{T}, x::Bounded) where {T<:Real} + v, unflatten_to_Real = flatten(T, x.unconstrained_value) + + function unflatten_Bounded(v_new::Vector{T}) + return Bounded( + unflatten_to_Real(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε + ) + end + + return v, unflatten_Bounded +end diff --git a/test/parameters.jl b/test/parameters.jl index c116822..6f8e17a 100644 --- a/test/parameters.jl +++ b/test/parameters.jl @@ -1,116 +1,4 @@ -using ParameterHandling: Positive, Bounded -using ParameterHandling: vec_to_tril, tril_to_vec - -mvnormal(args...) = MvNormal(args...) -pdiagmat(args...) = PDiagMat(args...) - @testset "parameters" begin - @testset "postive" begin - @testset "$val" for val in [5.0, 0.001f0, 1.0e-7] - p = positive(val) - test_parameter_interface(p) - @test value(p) ≈ val - @test typeof(value(p)) === typeof(val) - end - - # Test edge cases around the size of the value relative to the error tol. - @test_throws ArgumentError positive(-0.1) - @test_throws ArgumentError positive(1e-11) - @test value(positive(1e-11, exp, 1e-12)) ≈ 1e-11 - end - - @testset "bounded" begin - @testset "$val" for val in [-0.05, -0.1 + 1e-12, 2.0 - 1e-11, 2.0 - 1e-12] - p = bounded(val, -0.1, 2.0) - test_parameter_interface(p) - @test value(p) ≈ val - end - - @test_throws ArgumentError bounded(-0.05, 0.0, 1.0) - end - - @testset "fixed" begin - @testset "plain" begin - val = (a=5.0, b=4.0) - p = fixed(val) - test_parameter_interface(p) - @test value(p) == val - end - - @testset "constrained" begin - val = 1.234 - constrained_val = positive(val) - p = fixed(constrained_val) - test_parameter_interface(p) - @test value(p) ≈ val - end - end - - @testset "deferred" begin - test_parameter_interface(deferred(sin, 0.5); check_inferred=tuple_infers) - test_parameter_interface(deferred(sin, positive(0.5)); check_inferred=tuple_infers) - test_parameter_interface( - deferred( - mvnormal, fixed(randn(5)), deferred(pdiagmat, positive.(rand(5) .+ 1e-1)) - ); - check_inferred=tuple_infers, - ) - end - - @testset "orthogonal" begin - is_almost_orthogonal(X::AbstractMatrix, tol) = norm(X'X - I) < tol - - @testset "nearest_orthogonal_matrix($T)" for T in [Float64, ComplexF64] - X_orth = ParameterHandling.nearest_orthogonal_matrix(randn(T, 5, 4)) - @test is_almost_orthogonal(X_orth, 1e-9) - X_orth_2 = ParameterHandling.nearest_orthogonal_matrix(X_orth) - @test X_orth ≈ X_orth_2 # nearest_orthogonal_matrix is a projection. - end - - X = orthogonal(randn(5, 4)) - @test X == X - test_parameter_interface(X) - @test is_almost_orthogonal(value(X), 1e-9) - - # We do not implement any custom rrules, so we only check that `Zygote` is able to - # differentiate, and assume that the result is correct if it doesn't error. - @testset "Zygote" begin - _, pb = Zygote.pullback(X -> value(orthogonal(X)), randn(3, 2)) - @test only(pb(randn(3, 2))) isa Matrix{<:Real} - end - end - - @testset "positive_definite" begin - @testset "vec_tril_conversion" begin - X = tril!(rand(3, 3)) - @test vec_to_tril(tril_to_vec(X)) == X - @test_throws ErrorException tril_to_vec(rand(4, 5)) - end - X_mat = ParameterHandling.A_At(rand(3, 3)) # Create a positive definite object - X = positive_definite(X_mat) - @test X == X - @test value(X) ≈ X_mat - @test isposdef(value(X)) - @test vec_to_tril(X.L) ≈ cholesky(X_mat).L - @test_throws ArgumentError positive_definite(rand(3, 3)) - test_parameter_interface(X) - - x, re = flatten(X) - Δl = first( - Zygote.gradient(x) do x - X = re(x) - return logdet(value(X)) - end, - ) - ΔL = first( - Zygote.gradient(vec_to_tril(X.L)) do L - return logdet(L * L') - end, - ) - @test vec_to_tril(Δl) == tril(ΔL) - ChainRulesTestUtils.test_rrule(vec_to_tril, x) - end - function objective_function(unflatten, flat_θ::Vector{<:Real}) θ = value(unflatten(flat_θ)) return abs2(θ.a) + abs2(θ.b) diff --git a/test/parameters_matrix.jl b/test/parameters_matrix.jl new file mode 100644 index 0000000..cd12028 --- /dev/null +++ b/test/parameters_matrix.jl @@ -0,0 +1,57 @@ +using ParameterHandling: vec_to_tril, tril_to_vec + +@testset "parameters_matrix.jl" begin + @testset "orthogonal" begin + is_almost_orthogonal(X::AbstractMatrix, tol) = norm(X'X - I) < tol + + @testset "nearest_orthogonal_matrix($T)" for T in [Float64, ComplexF64] + X_orth = ParameterHandling.nearest_orthogonal_matrix(randn(T, 5, 4)) + @test is_almost_orthogonal(X_orth, 1e-9) + X_orth_2 = ParameterHandling.nearest_orthogonal_matrix(X_orth) + @test X_orth ≈ X_orth_2 # nearest_orthogonal_matrix is a projection. + end + + X = orthogonal(randn(5, 4)) + @test X == X + test_parameter_interface(X) + @test is_almost_orthogonal(value(X), 1e-9) + + # We do not implement any custom rrules, so we only check that `Zygote` is able to + # differentiate, and assume that the result is correct if it doesn't error. + @testset "Zygote" begin + _, pb = Zygote.pullback(X -> value(orthogonal(X)), randn(3, 2)) + @test only(pb(randn(3, 2))) isa Matrix{<:Real} + end + end + + @testset "positive_definite" begin + @testset "vec_tril_conversion" begin + X = tril!(rand(3, 3)) + @test vec_to_tril(tril_to_vec(X)) == X + @test_throws ErrorException tril_to_vec(rand(4, 5)) + end + X_mat = ParameterHandling.A_At(rand(3, 3)) # Create a positive definite object + X = positive_definite(X_mat) + @test X == X + @test value(X) ≈ X_mat + @test isposdef(value(X)) + @test vec_to_tril(X.L) ≈ cholesky(X_mat).L + @test_throws ArgumentError positive_definite(rand(3, 3)) + test_parameter_interface(X) + + x, re = flatten(X) + Δl = first( + Zygote.gradient(x) do x + X = re(x) + return logdet(value(X)) + end, + ) + ΔL = first( + Zygote.gradient(vec_to_tril(X.L)) do L + return logdet(L * L') + end, + ) + @test vec_to_tril(Δl) == tril(ΔL) + ChainRulesTestUtils.test_rrule(vec_to_tril, x) + end +end diff --git a/test/parameters_meta.jl b/test/parameters_meta.jl new file mode 100644 index 0000000..d7d9bd4 --- /dev/null +++ b/test/parameters_meta.jl @@ -0,0 +1,32 @@ +@testset "parameters_meta.jl" begin + @testset "fixed" begin + @testset "plain" begin + val = (a=5.0, b=4.0) + p = fixed(val) + test_parameter_interface(p) + @test value(p) == val + end + + @testset "constrained" begin + val = 1.234 + constrained_val = positive(val) + p = fixed(constrained_val) + test_parameter_interface(p) + @test value(p) ≈ val + end + end + + @testset "deferred" begin + mvnormal(args...) = MvNormal(args...) + pdiagmat(args...) = PDiagMat(args...) + + test_parameter_interface(deferred(sin, 0.5); check_inferred=tuple_infers) + test_parameter_interface(deferred(sin, positive(0.5)); check_inferred=tuple_infers) + test_parameter_interface( + deferred( + mvnormal, fixed(randn(5)), deferred(pdiagmat, positive.(rand(5) .+ 1e-1)) + ); + check_inferred=tuple_infers, + ) + end +end diff --git a/test/parameters_scalar.jl b/test/parameters_scalar.jl new file mode 100644 index 0000000..31600a7 --- /dev/null +++ b/test/parameters_scalar.jl @@ -0,0 +1,27 @@ +using ParameterHandling: Positive, Bounded + +@testset "parameters_scalar.jl" begin + @testset "postive" begin + @testset "$val" for val in [5.0, 0.001f0, 1.0e-7] + p = positive(val) + test_parameter_interface(p) + @test value(p) ≈ val + @test typeof(value(p)) === typeof(val) + end + + # Test edge cases around the size of the value relative to the error tol. + @test_throws ArgumentError positive(-0.1) + @test_throws ArgumentError positive(1e-11) + @test value(positive(1e-11, exp, 1e-12)) ≈ 1e-11 + end + + @testset "bounded" begin + @testset "$val" for val in [-0.05, -0.1 + 1e-12, 2.0 - 1e-11, 2.0 - 1e-12] + p = bounded(val, -0.1, 2.0) + test_parameter_interface(p) + @test value(p) ≈ val + end + + @test_throws ArgumentError bounded(-0.05, 0.0, 1.0) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b864bd2..ff410d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,4 +17,7 @@ const tuple_infers = VERSION < v"1.5" ? false : true @testset "ParameterHandling.jl" begin include("flatten.jl") include("parameters.jl") + include("parameters_meta.jl") + include("parameters_scalar.jl") + include("parameters_matrix.jl") end