-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Documenter.jl
committed
Sep 3, 2024
1 parent
fc867b1
commit d3e1efa
Showing
73 changed files
with
31,953 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
v1.0.2 | ||
v1.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
v1.0.2 | ||
v1.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
v1.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"documenter":{"julia_version":"1.10.5","generation_timestamp":"2024-09-03T11:21:04","documenter_version":"1.6.0"}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
|
||
# LaplaceRedux | ||
|
||
```{julia} | ||
#| echo: false | ||
using Pkg; Pkg.activate("docs") | ||
# Import libraries | ||
using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra | ||
``` | ||
|
||
`LaplaceRedux.jl` is a library written in pure Julia that can be used for effortless Bayesian Deep Learning through Laplace Approximation (LA). In the development of this package I have drawn inspiration from this Python [library](https://aleximmer.github.io/Laplace/index.html#setup) and its companion [paper](https://arxiv.org/abs/2106.14806) [@daxberger2021laplace]. | ||
|
||
## 🚩 Installation | ||
|
||
The stable version of this package can be installed as follows: | ||
|
||
```{.julia} | ||
using Pkg | ||
Pkg.add("LaplaceRedux.jl") | ||
``` | ||
|
||
The development version can be installed like so: | ||
|
||
```{.julia} | ||
using Pkg | ||
Pkg.add("https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl") | ||
``` | ||
|
||
## 🏃 Getting Started | ||
|
||
If you are new to Deep Learning in Julia or simply prefer learning through videos, check out this awesome YouTube [tutorial](https://www.youtube.com/channel/UCQwQVlIkbalDzmMnr-0tRhw) by [doggo.jl](https://www.youtube.com/@doggodotjl/about) 🐶. Additionally, you can also find a [video](https://www.youtube.com/watch?v=oWko8FRj_64) of my presentation at JuliaCon 2022 on YouTube. | ||
|
||
## 🖥️ Basic Usage | ||
|
||
`LaplaceRedux.jl` can be used for any neural network trained in [`Flux.jl`](https://fluxml.ai/Flux.jl/dev/). Below we show basic usage examples involving two simple models for a regression and a classification task, respectively. | ||
|
||
### Regression | ||
|
||
```{julia} | ||
#| echo: false | ||
using LaplaceRedux | ||
using LaplaceRedux.Data: toy_data_regression | ||
using Flux.Optimise: update!, Adam | ||
# Data: | ||
n = 150 # number of observations | ||
σtrue = 0.3 # true observational noise | ||
x, y = toy_data_regression(n; noise=σtrue) | ||
xs = [[x] for x in x] | ||
X = permutedims(x) | ||
data = zip(xs,y) | ||
# Model: | ||
n_hidden = 50 | ||
D = size(X,1) | ||
nn = Chain( | ||
Dense(D, n_hidden, tanh), | ||
Dense(n_hidden, 1) | ||
) | ||
loss(x, y) = Flux.Losses.mse(nn(x), y) | ||
# Training: | ||
opt = Adam(1e-3) | ||
epochs = 1000 | ||
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data)) | ||
show_every = epochs/10 | ||
for epoch = 1:epochs | ||
for d in data | ||
gs = gradient(Flux.params(nn)) do | ||
l = loss(d...) | ||
end | ||
update!(opt, Flux.params(nn), gs) | ||
end | ||
if epoch % show_every == 0 | ||
println("Epoch " * string(epoch)) | ||
@show avg_loss(data) | ||
end | ||
end | ||
``` | ||
|
||
A complete worked example for a regression model can be found in the [docs](https://www.paltmeyer.com/LaplaceRedux.jl/dev/tutorials/regression/). Here we jump straight to Laplace Approximation and take the pre-trained model `nn` as given. Then LA can be implemented as follows, where we specify the model `likelihood`. The plot shows the fitted values overlaid with a 95% confidence interval. As expected, predictive uncertainty quickly increases in areas that are not populated by any training data. | ||
|
||
```{julia} | ||
#| output: true | ||
la = Laplace(nn; likelihood=:regression) | ||
fit!(la, data) | ||
optimize_prior!(la) | ||
plot(la, X, y; zoom=-5, size=(500,500)) | ||
``` | ||
|
||
```{julia} | ||
#| echo: false | ||
using Plots.PlotMeasures | ||
theme(:wong) | ||
anim = Animation() | ||
N = 100 | ||
_every = Int(round(N/10)) | ||
zoom = -3 | ||
for n in _every:_every:length(first(data,N)) | ||
la = Laplace(nn; likelihood=:regression) | ||
fit!(la, Iterators.take(data,n)) | ||
optimize_prior!(la) | ||
plt = plot( | ||
la, X[:,1:n], y[1:n]; | ||
ylim=(minimum(y)+zoom,maximum(y)-zoom), | ||
zoom=-2, size=(600,200), clegend=false, axis=nothing, legend=false, margin = -5mm, ms=4 | ||
) | ||
frame(anim, plt) | ||
end | ||
gif(anim, "dev/www/intro.gif", fps=1) | ||
``` | ||
|
||
### Binary Classification | ||
|
||
```{julia} | ||
#| echo: false | ||
using LaplaceRedux.Data: toy_data_non_linear | ||
# Data: | ||
xs, ys = toy_data_non_linear(200) | ||
X = hcat(xs...) # bring into tabular format | ||
data = zip(xs,ys) | ||
# Model: | ||
n_hidden = 10 | ||
D = size(X,1) | ||
nn = Chain( | ||
Dense(D, n_hidden, σ), | ||
Dense(n_hidden, 1) | ||
) | ||
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) | ||
# Training: | ||
opt = Adam() | ||
epochs = 500 | ||
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data)) | ||
show_every = epochs/10 | ||
for epoch = 1:epochs | ||
for d in data | ||
gs = gradient(Flux.params(nn)) do | ||
l = loss(d...) | ||
end | ||
update!(opt, Flux.params(nn), gs) | ||
end | ||
if epoch % show_every == 0 | ||
println("Epoch " * string(epoch)) | ||
@show avg_loss(data) | ||
end | ||
end | ||
``` | ||
|
||
Once again we jump straight to LA and refer to the [docs](https://www.paltmeyer.com/LaplaceRedux.jl/dev/tutorials/mlp/) for a complete worked example involving binary classification. In this case we need to specify `likelihood=:classification`. The plot below shows the resulting posterior predictive distributions as contours in the two-dimensional feature space: note how the **Plugin** Approximation on the left compares to the Laplace Approximation on the right. | ||
|
||
```{julia} | ||
#| output: true | ||
la = Laplace(nn; likelihood=:classification) | ||
fit!(la, data) | ||
la_untuned = deepcopy(la) # saving for plotting | ||
optimize_prior!(la; n_steps=100) | ||
# Plot the posterior predictive distribution: | ||
zoom=0 | ||
p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1)) | ||
p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom) | ||
p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom) | ||
plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400)) | ||
``` | ||
|
||
|
||
## 📢 JuliaCon 2022 | ||
|
||
This project was presented at JuliaCon 2022 in July 2022. See [here](https://pretalx.com/juliacon-2022/talk/Z7MXFS/) for details. | ||
|
||
## 🛠️ Contribute | ||
|
||
Contributions are very much welcome! Please follow the [SciML ColPrac guide](https://github.com/SciML/ColPrac). You may want to start by having a look at any open [issues](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/issues). | ||
|
||
## 🎓 References | ||
|
||
|
Oops, something went wrong.