From 35ac2d8c25ebca59518754889b3c2a22f7796e29 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 13 Sep 2024 06:08:06 +0200 Subject: [PATCH] changes --- src/direct_mlj.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index d341f4a..cb4c3c1 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -25,7 +25,7 @@ It has the following Hyperparameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. """ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic - flux_model::Flux.Chain= nothing + flux_model::Flux.Chain = nothing subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -39,7 +39,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist end -function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing) +function MMI.fit(m::LaplaceRegressor, verbosity, X, y) features = Tables.schema(X).names X = MLJBase.matrix(X) @@ -47,15 +47,17 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing) la = LaplaceRedux.Laplace( m.flux_model; likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, σ=m.σ, μ₀=m.μ₀, P₀=m.P₀, ) + println(la) + # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) @@ -122,7 +124,7 @@ end -function MMI.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing) +function MMI.fit(m::LaplaceClassifier, verbosity, X, y) features = Tables.schema(X).names Xmatrix = MLJBase.matrix(X) decode = y[1]