Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

116 make plot and predict consistent #118

Merged
merged 6 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 39 additions & 27 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.3"
julia_version = "1.10.5"
manifest_format = "2.0"
project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca"

Expand Down Expand Up @@ -85,9 +85,9 @@ version = "3.5.1+1"

[[deps.ArrayInterface]]
deps = ["Adapt", "LinearAlgebra"]
git-tree-sha1 = "f54c23a5d304fb87110de62bace7777d59088c34"
git-tree-sha1 = "3640d077b6dafd64ceb8fd5c1ec76f7ca53bcf76"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.15.0"
version = "7.16.0"

[deps.ArrayInterface.extensions]
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
Expand Down Expand Up @@ -209,9 +209,9 @@ version = "0.9.2+0"

[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
git-tree-sha1 = "f3b237289a5a77c759b2dd5d4c2ff641d67c4030"
git-tree-sha1 = "33576c7c1b2500f8e7e6baa082e04563203b3a45"
uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
version = "0.3.4"
version = "0.3.5"

[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
Expand Down Expand Up @@ -359,17 +359,18 @@ uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
version = "0.1.13"

[[deps.ConstructionBase]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "a33b7ced222c6165f624a3f2b55945fac5a598d9"
git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.5.7"
version = "1.5.8"

[deps.ConstructionBase.extensions]
ConstructionBaseIntervalSetsExt = "IntervalSets"
ConstructionBaseLinearAlgebraExt = "LinearAlgebra"
ConstructionBaseStaticArraysExt = "StaticArrays"

[deps.ConstructionBase.weakdeps]
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.ContextVariablesX]]
Expand Down Expand Up @@ -569,19 +570,24 @@ uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
version = "1.16.3"

[[deps.FilePathsBase]]
deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"]
git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa"
deps = ["Compat", "Dates"]
git-tree-sha1 = "7878ff7172a8e6beedd1dea14bd27c3c6340d361"
uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
version = "0.9.21"
version = "0.9.22"
weakdeps = ["Mmap", "Test"]

[deps.FilePathsBase.extensions]
FilePathsBaseMmapExt = "Mmap"
FilePathsBaseTestExt = "Test"

[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[deps.FillArrays]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "fd0002c0b5362d7eb952450ad5eb742443340d6e"
git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.12.0"
version = "1.13.0"
weakdeps = ["PDMats", "SparseArrays", "Statistics"]

[deps.FillArrays.extensions]
Expand Down Expand Up @@ -841,10 +847,10 @@ uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"

[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"]
git-tree-sha1 = "67d4690d32c22e28818a434b293a374cc78473d3"
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
version = "0.4.51"
version = "0.4.53"

[[deps.JLFzf]]
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
Expand All @@ -854,9 +860,9 @@ version = "0.1.8"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0"
version = "1.6.0"

[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
Expand Down Expand Up @@ -884,9 +890,9 @@ version = "0.2.4"

[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "35ceea58aa34ad08b1ae00a52622c62d1cfb8ce2"
git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.24"
version = "0.9.25"

[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
Expand Down Expand Up @@ -1444,9 +1450,9 @@ version = "1.4.1"

[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"]
git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf"
git-tree-sha1 = "45470145863035bb124ca51b320ed35d071cc6c2"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.40.5"
version = "1.40.8"

[deps.Plots.extensions]
FileIOExt = "FileIO"
Expand Down Expand Up @@ -1514,9 +1520,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.10.2"

[[deps.PtrArrays]]
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
version = "1.2.0"
version = "1.2.1"

[[deps.Qt6Base_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"]
Expand Down Expand Up @@ -1544,9 +1550,15 @@ version = "6.7.1+1"

[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "e237232771fdafbae3db5c31275303e056afaa9f"
git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.10.1"
version = "2.11.0"

[deps.QuadGK.extensions]
QuadGKEnzymeExt = "Enzyme"

[deps.QuadGK.weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[[deps.RData]]
deps = ["CategoricalArrays", "CodecZlib", "DataFrames", "Dates", "FileIO", "Requires", "TimeZones", "Unicode"]
Expand Down Expand Up @@ -2274,7 +2286,7 @@ version = "0.15.2+0"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"
version = "5.11.0+0"

[[deps.libdecor_jll]]
deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"]
Expand Down
65 changes: 45 additions & 20 deletions src/baselaplace/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,24 @@ function has_softmax_or_sigmoid_final_layer(model::Flux.Chain)
return has_finaliser
end

"""
@doc raw"""
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)

Compute the functional variance for the GLM predictive. Dispatches to the appropriate method based on the Hessian structure.
Computes the functional variance for the GLM predictive as `map(j -> (j' * Σ * j), eachrow(𝐉))` which is an (output × output) predictive covariance matrix. Formally, we have ``{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}`` where ``\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta`` is the Jacobian evaluated at the MAP estimate.

Dispatches to the appropriate method based on the Hessian structure.
"""
function functional_variance(la, 𝐉)
return functional_variance(la, la.est_params.hessian_structure, 𝐉)
end

"""
@doc raw"""
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)

Computes the linearized GLM predictive.
Computes the linearized GLM predictive from neural network with a Laplace approximation to the posterior ``p(\theta|\mathcal{D})=\mathcal{N}(\hat\theta,\Sigma)``.
This is the distribution on network outputs given by ``p(f(x)|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta})``.
For the Bayesian predictive distribution, see [`predict`](@ref).


# Arguments

Expand All @@ -49,7 +54,7 @@ Computes the linearized GLM predictive.

# Examples

```julia-repl
```julia
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
Expand All @@ -58,42 +63,55 @@ nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))
```
"""
function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X)
fμ = reshape(fμ, Flux.outputsize(la.model, size(X)))
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
fstd = sqrt.(fvar)
normal_distr = [Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 2)]
normal_distr = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)]
return (normal_distr, fμ, fvar)
end

"""
predict(la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true)
@doc raw"""
predict(
la::AbstractLaplace,
X::AbstractArray;
link_approx=:probit,
predict_proba::Bool=true,
ret_distr::Bool=false,
)

Computes predictions from Bayesian neural network.
Computes the Bayesian predictive distribution from a neural network with a Laplace approximation to the posterior ``p(\theta | \mathcal{D}) = \mathcal{N}(\hat\theta, \Sigma)``.

# Arguments

- `la::AbstractLaplace`: A Laplace object.
- `X::AbstractArray`: Input data.
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model.
- `return_distr::Bool=false`: if `false` (default), the function output either the direct output of the chain or pseudo-probabilities (if predict_proba= true).
- `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`).
if `true`, `predict` returns a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks.

# Returns
For classification tasks, LaplaceRedux provides different options:
if ret_distr is false:
- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
if ret_distr is true:
- a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks.

For classification tasks:

1. If `ret_distr` is `false`, `predict` returns `fμ`, i.e. the mean of the predictive distribution, which corresponds to the MAP estimate if the link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
2. If `ret_distr` is `true`, `predict` returns a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks.

For regression tasks:
- `normal_distr::Distributions.Normal`: the array of Normal distributions computed by `glm_predictive_distribution`.

1. If `ret_distr` is `false`, `predict` returns the mean and the variance of the predictive distribution. The output shape is column-major as in Flux.
2. If `ret_distr` is `true`, `predict` returns the predictive posterior distribution, namely:

``p(y|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta} + \sigma^2 \mathbf{I})``

# Examples

```julia-repl
```julia
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
Expand All @@ -111,15 +129,22 @@ function predict(
predict_proba::Bool=true,
ret_distr::Bool=false,
)
normal_distr, fμ, fvar = glm_predictive_distribution(la, X)
_, fμ, fvar = glm_predictive_distribution(la, X)

# Regression:
if la.likelihood == :regression

# Add observational noise:
pred_var = fvar .+ la.prior.σ^2
fstd = sqrt.(pred_var)
pred_dist = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)]

if ret_distr
return reshape(normal_distr, (:, 1))
return reshape(pred_dist, (:, 1))
else
return fμ, fvar
return fμ, pred_var
end

end

# Classification:
Expand Down
6 changes: 3 additions & 3 deletions src/full.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ function _fit!(
return la.posterior.n_data = n_data
end

"""
functional_variance(la::Laplace,𝐉)
@doc raw"""
functional_variance(la::Laplace, hessian_structure::FullHessian, 𝐉)

Compute the linearized GLM predictive variance as `𝐉ₙΣ𝐉ₙ'` where `𝐉=∇f(x;θ)|θ̂` is the Jacobian evaluated at the MAP estimate and `Σ = P⁻¹`.
Computes the functional variance for the GLM predictive as `map(j -> (j' * Σ * j), eachrow(𝐉))` which is an (output × output) predictive covariance matrix. Formally, we have ``{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}`` where ``\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta`` is the Jacobian evaluated at the MAP estimate.
"""
function functional_variance(la::Laplace, hessian_structure::FullHessian, 𝐉)
Σ = posterior_covariance(la)
Expand Down
2 changes: 1 addition & 1 deletion src/kronecker/kron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function _fit!(
end

"""
functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)
functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)

Compute functional variance for the GLM predictive: as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ,
where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of model output `𝐉=∇f(x;θ)|θ̂`.
Expand Down
Loading