Clarify finiteness assumptions, add function to check. #105

Status: Open. Wants to merge 1 commit into base: master.
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -185,4 +185,5 @@ LogDensityProblems.dimension
LogDensityProblems.logdensity
LogDensityProblems.logdensity_and_gradient
LogDensityProblems.logdensity_gradient_and_hessian
LogDensityProblems.is_valid_result
```
72 changes: 72 additions & 0 deletions src/LogDensityProblems.jl
@@ -86,6 +86,7 @@ Evaluate the log density `ℓ` at `x`, which has length compatible with its

Return a real number, which may or may not be finite (it can also be `NaN`). Non-finite values
other than `-Inf` are invalid but do not throw an error; the caller should handle them appropriately.
Cf. [`is_valid_result`](@ref).

# Note about constants

@@ -123,6 +124,10 @@ Return two values:

The first argument (the log density) can be shifted by a constant, see the note for
[`logdensity`](@ref).

The caller should be prepared to handle non-finite derivatives, even though such values are incorrect.
Cf. [`is_valid_result`](@ref).

"""
function logdensity_and_gradient end

@@ -148,9 +153,76 @@ Return three values:

The first argument (the log density) can be shifted by a constant, see the note for
[`logdensity`](@ref).

The caller should be prepared to handle non-finite derivatives, even though such values are incorrect.
Cf. [`is_valid_result`](@ref).
"""
function logdensity_gradient_and_hessian end

"""
is_valid_result(f, [∇f], [∇²f])::Bool

Return `true` if and only if the log density `f` and its derivatives (optional) are *valid* in the sense defined below, otherwise `false`.

# Discussion

The API of this package defines an *interface* for working with log densities and gradients, but since these are implemented by the user and/or AD frameworks, it cannot guarantee the *correctness* of the results. This function allows the caller to check conveniently for some common numerical problems.

Ideally, log densities are almost everywhere finite and differentiable, but practical computation often violates this assumption. The caller should be prepared to deal with this, either by throwing an error, rejecting that point, or in some other way. Caller functions may of course allow the user to skip this check, which may result in a minor speedup, but could lead to bugs that are very hard to diagnose (e.g. propagated `NaN`s could cause problems much later in the code).

An example using this function would be
```julia
ℓq, ∇ℓq = logdensity_and_gradient(ℓ, q)
if is_valid_result(ℓq, ∇ℓq)
    # everything is finite, or log density is -Inf, proceed accordingly
    # ...
elseif !strict # an option in the API of the caller
    # something went wrong, but proceed and treat it as `-Inf`
    # ...
elseif is_valid_result(ℓq)
    error("Gradient has non-finite elements.")
else
    error("Invalid log posterior.")
end
```

# Definitions

Log densities are *valid* if they are *finite* real numbers or equal to ``-\\infty``.

Derivatives are *valid* if all elements are finite. When the log density is ``-\\infty``, however, the derivatives should be *ignored*.

*All other possibilities are invalid*, including

1. log densities that are not `::Real` (e.g. `1+2im`, `missing`),
2. non-finite log densities that are not `-Inf` (e.g. `NaN`, `Inf`),
3. derivatives (gradients or Hessians) with non-finite elements for finite log densities.

Note that this function does not check

1. *dimensions*: it is assumed that bugs of this kind are much rarer in AD implementations.

2. *symmetry of the Hessian* (cf Schwarz's/Clairaut's/Young's theorem).
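
Callers that need these checks can add them separately. The following is only a minimal sketch, assuming a hypothetical helper `has_consistent_shape` (not part of this package) and a known dimension `d`. Note that `issymmetric` tests exact equality, which may be too strict for Hessians computed in floating point.

```julia
using LinearAlgebra: issymmetric

# Hypothetical helper, not part of LogDensityProblems: it checks only the shapes
# and (optionally) exact symmetry, and leaves finiteness to `is_valid_result`.
function has_consistent_shape(d::Integer, ∇f::AbstractVector, ∇²f::AbstractMatrix;
                              check_symmetry::Bool = true)
    length(∇f) == d || return false          # gradient must have `d` elements
    size(∇²f) == (d, d) || return false      # Hessian must be `d × d`
    !check_symmetry || issymmetric(∇²f)      # exact symmetry, only if requested
end
```

Such a check would typically be combined with `is_valid_result` only when the log density is finite, since derivatives are ignored for `-Inf`.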
"""
function is_valid_result end

# a version of `isfinite` that is only true for real numbers
_is_finite(x) = x isa Real && isfinite(x)

is_valid_result(args...) = false

is_valid_result(x::Real) = _is_finite(x) || x == -Inf

function is_valid_result(x::Real, ∇x)
    is_valid_result(x) || return false
    x == -Inf || (∇x isa AbstractVector && all(_is_finite, ∇x))
end

function is_valid_result(x::Real, ∇x, ∇²x)
    is_valid_result(x) || return false
    x == -Inf || (∇x isa AbstractVector && all(_is_finite, ∇x) &&
                  ∇²x isa AbstractMatrix && all(_is_finite, ∇²x))
end

include("utilities.jl")

end # module
30 changes: 28 additions & 2 deletions test/runtests.jl
@@ -1,6 +1,6 @@
using LogDensityProblems, Test, Random
import LogDensityProblems: capabilities, dimension, logdensity
using LogDensityProblems: logdensity_and_gradient, LogDensityOrder
using LogDensityProblems: logdensity_and_gradient, LogDensityOrder, is_valid_result

####
#### test setup and utilities
@@ -79,9 +79,35 @@ Base.show(io::IO, ::TestLogDensity) = print(io, "TestLogDensity")
end

####
#### utilities
#### valid results
####

@testset "valid results" begin
    @test is_valid_result(1.0)
    @test is_valid_result(-Inf)
    @test !is_valid_result(Inf)
    @test !is_valid_result(NaN)
    @test !is_valid_result(missing)
    @test !is_valid_result("a fish")

    @test is_valid_result(1.0, [2.0, 3.0]) # all finite
    @test !is_valid_result(Inf, [2.0, 3.0]) # invalid
    @test !is_valid_result(NaN, [2.0, 3.0]) # invalid
    @test is_valid_result(-Inf, [NaN, Inf]) # gradient ignored
    @test is_valid_result(-Inf, "wrong type") # wrong type but ignored
    @test is_valid_result(-Inf, ["wrong element", 1.0]) # gradient ignored

    @test is_valid_result(1.0, [2.0, 3], [4.0 5; 6 7]) # non-symmetric but OK
    @test !is_valid_result(Inf, [2.0, 3], [4.0 5; 6 7]) # invalid
    @test !is_valid_result(NaN, [2.0, 3], [4.0 5; 6 7]) # invalid
    @test !is_valid_result(:a_fish, [2.0, 3], [4.0 5; 6 7]) # invalid
    @test is_valid_result(-Inf, [2.0, 3], [NaN 5; 6 7]) # Hessian ignored
    @test is_valid_result(-Inf, "bad to the", :bone)
end

####
#### utilities
####

@testset "stresstest" begin
@info "stress testing"