From 8edccacc2ecc80291375dc3e29ffd5091da0be63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Tue, 23 May 2023 14:15:36 +0200 Subject: [PATCH] Clarify finiteness assumptions, add function to check. Fixes #102. --- docs/src/index.md | 1 + src/LogDensityProblems.jl | 72 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 30 ++++++++++++++-- 3 files changed, 101 insertions(+), 2 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 95ab3c9..eecf965 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -185,4 +185,5 @@ LogDensityProblems.dimension LogDensityProblems.logdensity LogDensityProblems.logdensity_and_gradient LogDensityProblems.logdensity_gradient_and_hessian +LogDensityProblems.is_valid_result ``` diff --git a/src/LogDensityProblems.jl b/src/LogDensityProblems.jl index 2cea958..14f444b 100644 --- a/src/LogDensityProblems.jl +++ b/src/LogDensityProblems.jl @@ -86,6 +86,7 @@ Evaluate the log density `ℓ` at `x`, which has length compatible with its Return a real number, which may or may not be finite (can also be `NaN`). Non-finite values other than `-Inf` are invalid but do not error, caller should deal with these appropriately. +Cf [`is_valid_result`](@ref). # Note about constants @@ -123,6 +124,10 @@ Return two values: The first argument (the log density) can be shifted by a constant, see the note for [`logdensity`](@ref). + +Caller should be prepared to handle non-finite derivatives, even if they are incorrect. +Cf [`is_valid_result`](@ref). + """ function logdensity_and_gradient end @@ -148,9 +153,76 @@ Return three values: The first argument (the log density) can be shifted by a constant, see the note for [`logdensity`](@ref). + +Caller should be prepared to handle non-finite derivatives, even if they are incorrect. +Cf [`is_valid_result`](@ref). """ function logdensity_gradient_and_hessian end +""" + is_valid_result(f, [∇f], [∇²f])::Bool + +Return `true` if and only if the log density `y` and its derivaties (optional) are *valid* in the sense defined below, otherwise `false`. + +# Discussion + +The API of this package defines an *interface* for working with log densities and gradients, but since the latter are implemented by the user and/or AD frameworks, it cannot impose *correctness* of these results. This function allows the caller to check for some common numerical problems conveniently. + +Ideally, log densities are almost everywhere finite and differentiable, but practical computation often violates this assumption. The caller should be prepared to deal with this, either by throwing an error, rejecting that point, or some other way. Caller functions may of course allow the user to skip this check, which may result in a minor speedup, but could lead to bugs that are very hard to diagnose (as eg propagation of `NaN`s could cause problems much later in the code). + +An example using this function would be +```julia +ℓq, ∇ℓq = logdensity_and_gradient(ℓ, q) +if is_valid_result(ℓq, ∇ℓq) + # everything is finite, or log density is -Inf, proceed accordingly + # ... +elseif !strict # an option in the API of the caller + # something went wrong, but proceed and treat it an `-Inf` + # ... +elseif is_valid_result(ℓq) + error("Gradient has non-finite elements.") +else + error("Invalid log posterior.") +end +``` + +# Definitions + +Log densities are *valid* if they are *finite* real numbers or equal to ``-\\infty``. + +Derivatives are *valid* if all elements are finite. But for ``-\\infty`` log density, the derivatives should be *ignored*. + +*All other possibilities are invalid*, including + +1. log densities that are not `::Real` (eg `1+2im`, `missing`), +2. non-finite log densities that are not `-Inf` (eg `NaN`, `Inf`), +3. derivatives (gradients or Hessians) with non-finite elements for finite log densities + +Note that this function does not check + +1. *dimensions* --- it is assumed that those kind of bugs are much more rare in AD implementations. + +2. *symmetry of the Hessian* (cf Schwarz's/Clairaut's/Young's theorem). +""" +function is_valid_result end + +# a version of `isfinite` that is only true for real numbers +_is_finite(x) = x isa Real && isfinite(x) + +is_valid_result(args...) = false + +is_valid_result(x::Real) = _is_finite(x) || x == -Inf + +function is_valid_result(x::Real, ∇x) + is_valid_result(x) || return false + x == -Inf || (∇x isa AbstractVector && all(_is_finite, ∇x)) +end + +function is_valid_result(x::Real, ∇x, ∇²x) + is_valid_result(x) || return false + x == -Inf || (∇x isa AbstractVector && all(_is_finite, ∇x) && ∇²x isa AbstractMatrix && all(_is_finite, ∇²x)) +end + include("utilities.jl") end # module diff --git a/test/runtests.jl b/test/runtests.jl index 45875e8..0eb2f46 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using LogDensityProblems, Test, Random import LogDensityProblems: capabilities, dimension, logdensity -using LogDensityProblems: logdensity_and_gradient, LogDensityOrder +using LogDensityProblems: logdensity_and_gradient, LogDensityOrder, is_valid_result #### #### test setup and utilities @@ -79,9 +79,35 @@ Base.show(io::IO, ::TestLogDensity) = print(io, "TestLogDensity") end #### -#### utilities +#### valid results #### +@testset "valid results" begin + @test is_valid_result(1.0) + @test is_valid_result(-Inf) + @test !is_valid_result(Inf) + @test !is_valid_result(NaN) + @test !is_valid_result(missing) + @test !is_valid_result("a fish") + + @test is_valid_result(1.0, [2.0, 3.0]) # all finite + @test !is_valid_result(Inf, [2.0, 3.0]) # invalid + @test !is_valid_result(NaN, [2.0, 3.0]) # invalid + @test is_valid_result(-Inf, [NaN, Inf]) # gradient ignored + @test is_valid_result(-Inf, "wrong type") # wrong type but ignored + @test is_valid_result(-Inf, ["wrong element", 1.0]) # gradient ignored + + @test is_valid_result(1.0, [2.0, 3], [4.0 5; 6 7]) # non-symmetric but OK + @test !is_valid_result(Inf, [2.0, 3], [4.0 5; 6 7]) # invalid + @test !is_valid_result(NaN, [2.0, 3], [4.0 5; 6 7]) # invalid + @test !is_valid_result(:a_fish, [2.0, 3], [4.0 5; 6 7]) # invalid + @test is_valid_result(-Inf, [2.0, 3], [NaN 5; 6 7]) # Hessian ignored + @test is_valid_result(-Inf, "bad to the", :bone) +end + +#### +#### utilities +#### @testset "stresstest" begin @info "stress testing"