-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,9 @@ using Distributions: Normal | |
A mutable struct representing a Laplace regression model. | ||
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. | ||
It has the following Hyperparameters: | ||
- `flux_model`: A flux model provided by the user and compatible with the dataset. | ||
- `flux_model`: A Flux model provided by the user and compatible with the dataset. | ||
- `flux_loss` : a Flux loss function | ||
- `optimiser` = a Flux optimiser | ||
- `epochs`: The number of training epochs. | ||
- `batch_size`: The batch size. | ||
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. | ||
|
@@ -23,12 +25,12 @@ It has the following Hyperparameters: | |
- `σ`: the standard deviation of the prior distribution. | ||
- `μ₀`: the mean of the prior distribution. | ||
- `P₀`: the covariance matrix of the prior distribution. | ||
- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. | ||
- `fit_prior_nsteps`: the number of steps used to fit the priors. | ||
""" | ||
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic | ||
flux_model::Flux.Chain = nothing | ||
flux_loss = Flux.Losses.mse | ||
optimiser = Adam() | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
pat-alt
Member
|
||
epochs::Integer = 1000::(_ > 0) | ||
batch_size::Integer = 32::(_ > 0) | ||
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) | ||
|
@@ -39,7 +41,6 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist | |
σ::Float64 = 1.0 | ||
This comment has been minimized.
Sorry, something went wrong.
pasq-cat
Author
Member
|
||
μ₀::Float64 = 0.0 | ||
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing | ||
#ret_distr::Bool = false::(_ in (true, false)) | ||
fit_prior_nsteps::Int = 100::(_ > 0) | ||
end | ||
|
||
|
@@ -72,12 +73,11 @@ This function performs the following steps: | |
9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success. | ||
""" | ||
function MMI.fit(m::LaplaceRegressor, verbosity, X, y) | ||
#features = Tables.schema(X).names | ||
|
||
X = MLJBase.matrix(X) |> permutedims | ||
y = reshape(y, 1, :) | ||
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) | ||
opt_state = Flux.setup(Adam(), m.flux_model) | ||
opt_state = Flux.setup(m.optimiser(), m.flux_model) | ||
|
||
for epoch in 1:(m.epochs) | ||
Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y | ||
|
@@ -139,11 +139,11 @@ end | |
A mutable struct representing a Laplace Classification model. | ||
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. | ||
The model also has the following parameters: | ||
- `flux_model`: A flux model provided by the user and compatible with the dataset. | ||
- `flux_model`: A Flux model provided by the user and compatible with the dataset. | ||
- `flux_loss` : a Flux loss function | ||
- `optimiser` = a Flux optimiser | ||
- `epochs`: The number of training epochs. | ||
- `batch_size`: The batch size. | ||
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. | ||
|
@@ -154,13 +154,12 @@ The model also has the following parameters: | |
- `μ₀`: the mean of the prior distribution. | ||
- `P₀`: the covariance matrix of the prior distribution. | ||
- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. | ||
- `predict_proba`: a boolean that select whether to predict probabilities or not. | ||
- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. | ||
- `fit_prior_nsteps`: the number of steps used to fit the priors. | ||
""" | ||
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic | ||
flux_model::Flux.Chain = nothing | ||
flux_loss = Flux.Losses.logitcrossentropy | ||
optimiser = Adam() | ||
epochs::Integer = 1000::(_ > 0) | ||
batch_size::Integer = 32::(_ > 0) | ||
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) | ||
|
@@ -171,7 +170,6 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis | |
σ::Float64 = 1.0 | ||
μ₀::Float64 = 0.0 | ||
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing | ||
#ret_distr::Bool = false::(_ in (true, false)) | ||
fit_prior_nsteps::Int = 100::(_ > 0) | ||
link_approx::Symbol = :probit::(_ in (:probit, :plugin)) | ||
end | ||
|
@@ -207,7 +205,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) | |
y_plain = MLJBase.int(y) .- 1 | ||
y_onehot = Flux.onehotbatch(y_plain, unique(y_plain)) | ||
data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) | ||
opt_state = Flux.setup(Adam(), m.flux_model) | ||
opt_state = Flux.setup(m.optimiser, m.flux_model) | ||
|
||
for epoch in 1:(m.epochs) | ||
Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot | ||
|
@pat-alt i did not find what kind of default type to set for these 2 parametersso i just defined the default values..