diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 77409e1..3e16628 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -13,7 +13,9 @@ using Distributions: Normal
 A mutable struct representing a Laplace regression model.
 It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
 It has the following Hyperparameters:
-- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `flux_loss`: a Flux loss function.
+- `optimiser`: a Flux optimiser.
 - `epochs`: The number of training epochs.
 - `batch_size`: The batch size.
 - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
@@ -23,12 +25,12 @@ It has the following Hyperparameters:
 - `σ`: the standard deviation of the prior distribution.
 - `μ₀`: the mean of the prior distribution.
 - `P₀`: the covariance matrix of the prior distribution.
-- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.
 - `fit_prior_nsteps`: the number of steps used to fit the priors.
 """
 MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
     flux_model::Flux.Chain = nothing
     flux_loss = Flux.Losses.mse
+    optimiser = Adam()
     epochs::Integer = 1000::(_ > 0)
     batch_size::Integer = 32::(_ > 0)
     subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -39,7 +41,6 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
     σ::Float64 = 1.0
     μ₀::Float64 = 0.0
     P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
-    #ret_distr::Bool = false::(_ in (true, false))
     fit_prior_nsteps::Int = 100::(_ > 0)
 end
@@ -72,12 +73,11 @@ This function performs the following steps:
 9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success.
 """
 function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
-    #features = Tables.schema(X).names
     X = MLJBase.matrix(X) |> permutedims
     y = reshape(y, 1, :)
     data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
-    opt_state = Flux.setup(Adam(), m.flux_model)
+    opt_state = Flux.setup(m.optimiser, m.flux_model)

     for epoch in 1:(m.epochs)
         Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y
@@ -139,11 +139,11 @@ end
 A mutable struct representing a Laplace Classification model.
 It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-
-
 The model also has the following parameters:
-- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `flux_loss`: a Flux loss function.
+- `optimiser`: a Flux optimiser.
 - `epochs`: The number of training epochs.
 - `batch_size`: The batch size.
 - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
@@ -154,13 +154,12 @@ The model also has the following parameters:
 - `μ₀`: the mean of the prior distribution.
 - `P₀`: the covariance matrix of the prior distribution.
 - `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
-- `predict_proba`: a boolean that select whether to predict probabilities or not.
-- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.
 - `fit_prior_nsteps`: the number of steps used to fit the priors.
 """
 MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
     flux_model::Flux.Chain = nothing
     flux_loss = Flux.Losses.logitcrossentropy
+    optimiser = Adam()
     epochs::Integer = 1000::(_ > 0)
     batch_size::Integer = 32::(_ > 0)
     subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -171,7 +170,6 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
     σ::Float64 = 1.0
     μ₀::Float64 = 0.0
     P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
-    #ret_distr::Bool = false::(_ in (true, false))
     fit_prior_nsteps::Int = 100::(_ > 0)
     link_approx::Symbol = :probit::(_ in (:probit, :plugin))
 end
@@ -207,7 +205,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
     y_plain = MLJBase.int(y) .- 1
     y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))
     data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
-    opt_state = Flux.setup(Adam(), m.flux_model)
+    opt_state = Flux.setup(m.optimiser, m.flux_model)

     for epoch in 1:(m.epochs)
         Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot
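For reviewers, here is a minimal usage sketch of the patched regressor interface: it builds a small Flux network, passes a custom optimiser via the new `optimiser` hyperparameter (which `MMI.fit` now forwards to `Flux.setup` instead of a hard-coded `Adam()`), and trains through the standard MLJ machine workflow. The toy data, the network shape, and the assumption that `LaplaceRegressor` is exported by the package are illustrative, not part of this diff.

```julia
# Minimal sketch. Assumptions: LaplaceRegressor is in scope (e.g. via
# `using LaplaceRedux`); the data and network below are made up.
using Flux, MLJBase

# Toy regression data: 100 samples, 2 features.
X = MLJBase.table(randn(Float32, 100, 2))
y = randn(Float32, 100)

# A Flux model compatible with the 2-feature dataset, per the docstring.
nn = Chain(Dense(2 => 16, relu), Dense(16 => 1))

# The new `optimiser` hyperparameter; here Adam with a non-default
# learning rate, where the old code always used Adam().
model = LaplaceRegressor(;
    flux_model=nn,
    optimiser=Adam(0.01),
    epochs=100,
    batch_size=32,
)

mach = MLJBase.machine(model, X, y)
MLJBase.fit!(mach)                 # dispatches to MMI.fit, which calls Flux.setup(m.optimiser, ...)
y_pred = MLJBase.predict(mach, X)  # probabilistic predictions (presumably Normal distributions)
```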