removed predict_proba and ret_Distr from the struct
pasq-cat committed Sep 21, 2024
1 parent 33d84f5 commit 9731297
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions src/direct_mlj.jl
@@ -13,7 +13,9 @@ using Distributions: Normal
A mutable struct representing a Laplace regression model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
It has the following Hyperparameters:
-- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `flux_loss` : a Flux loss function
+- `optimiser` = a Flux optimiser
- `epochs`: The number of training epochs.
- `batch_size`: The batch size.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
@@ -23,12 +25,12 @@ It has the following Hyperparameters:
- `σ`: the standard deviation of the prior distribution.
- `μ₀`: the mean of the prior distribution.
- `P₀`: the covariance matrix of the prior distribution.
-- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux.Chain = nothing
flux_loss = Flux.Losses.mse
+optimiser = Adam()

pasq-cat (Author, Member) commented on Sep 21, 2024:

@pat-alt I did not find what kind of default type to set for these two parameters, so I just defined the default values.

pat-alt (Member) commented on Sep 23, 2024:

You mean the optimiser, or which parameters?

Would just go with the ones specified in MLJFlux, even though we're now interfacing MLJ directly: https://github.com/FluxML/MLJFlux.jl/blob/945016dc72abf0847c1551d7550b079e447b4c7c/src/types.jl#L34

pasq-cat (Author, Member) commented on Sep 23, 2024:

The optimiser and the loss. I do not understand what he did in MLJFlux: `b`, `f`, `o`, `l` are just letters that he set as types, but they are not defined anywhere. Do you want something like `optimiser::o = Adam()` and `flux_loss::l = Flux.Losses.mse`?
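For context, a minimal sketch of the pattern being referenced: in the linked MLJFlux file, `B`, `F`, `O`, `L` are type parameters of the model struct rather than predefined types, so the fields stay generic and the defaults can be plain instances. The names below are illustrative, not this package's code:

```julia
using Flux

# Sketch of the MLJFlux-style parametric struct (assumed from the linked
# types.jl; SomeModel and its fields are made up for illustration).
mutable struct SomeModel{B,O,L}
    builder::B     # anything that can build a Flux chain
    optimiser::O   # e.g. Flux.Adam()
    loss::L        # e.g. Flux.Losses.mse
    epochs::Int
end

# Because the fields are type-parametric, defaults can simply be instances:
SomeModel(; builder=nothing, optimiser=Flux.Adam(), loss=Flux.Losses.mse, epochs=100) =
    SomeModel(builder, optimiser, loss, epochs)
```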

epochs::Integer = 1000::(_ > 0)
batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -39,7 +41,6 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
σ::Float64 = 1.0

pasq-cat (Author, Member) commented on Sep 21, 2024:

@pat-alt I removed `ret_distr` and `predict_proba` from the parameters that the user can set, since:

1) to work with MLJ, `predict_proba` always has to be set to true;
2) there is a mismatch between the distributions returned by the predict function and the ones required by MLJ in the classification case. To avoid this issue, the interface uses the default predict behavior (no distributions from Distributions.jl) and then returns the `UnivariateFinite` object.

pat-alt (Member) commented on Sep 23, 2024:

Looks good to me!
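To illustrate point 2 of the thread above: MLJ probabilistic classifiers conventionally return `UnivariateFinite` objects built from plain probabilities rather than Distributions.jl distributions. A hedged sketch, with made-up classes and probabilities (not this package's code):

```julia
using MLJBase, CategoricalArrays
using Distributions: pdf

# Toy class pool and a 2 × 3 matrix of per-observation probabilities
# (rows = observations, columns = classes in the order of `classes(y)`).
y = categorical(["a", "b", "c"])
probs = [0.2 0.5 0.3;
         0.6 0.1 0.3]

# One UnivariateFinite per row, as an MLJ predict is expected to return:
yhat = MLJBase.UnivariateFinite(MLJBase.classes(y), probs)
pdf(yhat[1], "b")   # ⇒ 0.5
```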

μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
#ret_distr::Bool = false::(_ in (true, false))
fit_prior_nsteps::Int = 100::(_ > 0)
end
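For reference, a minimal usage sketch for the struct above; the keyword constructor is generated by `MLJBase.@mlj_model`, and the chain and settings here are made up:

```julia
using Flux

# Any user-supplied chain compatible with the data (4 features, 1 target):
chain = Chain(Dense(4, 16, relu), Dense(16, 1))

model = LaplaceRegressor(
    flux_model = chain,
    epochs     = 100,
    batch_size = 32,
    subset_of_weights = :last_layer,
)
```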

@@ -72,12 +73,11 @@ This function performs the following steps:
9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success.
"""
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
-#features = Tables.schema(X).names

X = MLJBase.matrix(X) |> permutedims
y = reshape(y, 1, :)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
-opt_state = Flux.setup(Adam(), m.flux_model)
+opt_state = Flux.setup(m.optimiser(), m.flux_model)

for epoch in 1:(m.epochs)
Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y
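The loop body is truncated in the diff, so here is a self-contained sketch of the same `DataLoader`/`Flux.setup`/`Flux.train!` pattern, with made-up shapes and model (the real method uses `m.flux_model` and `m.flux_loss`):

```julia
using Flux

mlp = Chain(Dense(3, 8, relu), Dense(8, 1))
X = rand(Float32, 3, 100)   # features × observations, as after permutedims
y = rand(Float32, 1, 100)   # targets reshaped to 1 × observations
loader = Flux.DataLoader((X, y); batchsize=32)
opt_state = Flux.setup(Flux.Adam(), mlp)

for epoch in 1:100
    # The do-block supplies the per-batch loss, as in the diff above.
    Flux.train!(mlp, loader, opt_state) do m, x, yb
        Flux.Losses.mse(m(x), yb)
    end
end
```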
@@ -139,11 +139,11 @@ end
A mutable struct representing a Laplace Classification model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
The model also has the following parameters:
-- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `flux_loss` : a Flux loss function
+- `optimiser` = a Flux optimiser
- `epochs`: The number of training epochs.
- `batch_size`: The batch size.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
@@ -154,13 +154,12 @@ The model also has the following parameters:
- `μ₀`: the mean of the prior distribution.
- `P₀`: the covariance matrix of the prior distribution.
- `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
-- `predict_proba`: a boolean that select whether to predict probabilities or not.
-- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux.Chain = nothing
flux_loss = Flux.Losses.logitcrossentropy
+optimiser = Adam()
epochs::Integer = 1000::(_ > 0)
batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -171,7 +170,6 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
σ::Float64 = 1.0
μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
-#ret_distr::Bool = false::(_ in (true, false))
fit_prior_nsteps::Int = 100::(_ > 0)
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
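By way of background on the `link_approx` field above (an assumption drawn from the Laplace-approximation literature, not from this diff): `:probit` rescales the posterior-mean logit by its variance before applying the sigmoid, while `:plugin` applies the sigmoid to the mean directly. The usual probit approximation is:

```latex
% mu and sigma^2 are the posterior mean and variance of the latent logit f(x).
p(y = 1 \mid x) \;\approx\; \mathrm{sigmoid}\!\left(\frac{\mu}{\sqrt{1 + \pi\sigma^{2}/8}}\right)
```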
@@ -207,7 +205,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
y_plain = MLJBase.int(y) .- 1
y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))
data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
-opt_state = Flux.setup(Adam(), m.flux_model)
+opt_state = Flux.setup(m.optimiser, m.flux_model)

for epoch in 1:(m.epochs)
Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot
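A hedged sketch of the label-encoding step shown in this classifier `fit` (toy labels; the ordering produced by `unique` is assumed to match the integer codes):

```julia
using Flux, MLJBase, CategoricalArrays

y = categorical(["cat", "dog", "cat", "bird"])
y_plain = MLJBase.int(y) .- 1                          # zero-based integer codes
y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))  # classes × observations
size(y_onehot)   # ⇒ (3, 4)
```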
