diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 1781fb6..e5b611d 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -44,30 +44,22 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic fit_prior_nsteps::Int = 100::(_ > 0) end -Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier} +LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier} # for fit: -MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :),nothing)) - +MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :)) function MMI.reformat(::LaplaceClassifier, X, y) X = MLJBase.matrix(X) |> permutedims + y = categorical(y) + unique_labels = y.pool.levels + y = Flux.onehotbatch(y, unique_labels) # One-hot encoding - - y= reshape(y, 1, :) - # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y[1, :]) .- 1 - # One-hot encoding of labels - unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding - y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - return X,(y_onehot, y[1]) + return X, y end -#MMI.selectrows(::LaplaceClassifier, I, Xmatrix, y) = (view(Xmatrix, :, I), (view(y[1],I),y[2])) -#MMI.selectrows(::LaplaceRegressor, I, Xmatrix, y) = (view(Xmatrix, :,I), (view(y[1],I),y[2]) ) -#for predict: -MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,) +MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,) @doc """ MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) @@ -85,7 +77,7 @@ Fit a Laplace model using the provided features and target values. - `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history. - `report`: A Namedtuple containing the loss history of the fitting process. """ -function MMI.fit(m::Laplace_Models, verbosity, X, y) +function MMI.fit(m::LaplaceModels, verbosity, X, y) decode = y[2] y= y[1] @@ -172,7 +164,7 @@ Update the Laplace model using the provided new data points. - `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history. - `report`: A Namedtuple containing the complete loss history of the fitting process. """ -function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y) +function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) decode = y[2] y_up=y[1] @@ -302,7 +294,7 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y end @doc """ - function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) + function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...) If both `m1` and `m2` are of `MLJType`, return `true` if the following conditions all hold, and `false` otherwise: @@ -329,7 +321,7 @@ meaning; see [`deep_properties`](@ref)) for details. If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. """ -function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) +function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...) typeof(m1) === typeof(m2) || return false names = propertynames(m1) propertynames(m2) === names || return false @@ -420,7 +412,7 @@ end - `loss`: The loss value of the posterior distribution. """ -function MMI.fitted_params(model::Laplace_Models, fitresult) +function MMI.fitted_params(model::LaplaceModels, fitresult) la, decode = fitresult posterior = la.posterior return ( @@ -447,7 +439,7 @@ Retrieve the training loss history from the given `report`. # Returns - A collection representing the loss history from the training report. """ -function MMI.training_losses(model::Laplace_Models, report) +function MMI.training_losses(model::LaplaceModels, report) return report.loss_history end @@ -467,7 +459,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) for LaplaceClassifier: - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data. """ -function MMI.predict(m::Laplace_Models, fitresult, Xnew) +function MMI.predict(m::LaplaceModels, fitresult, Xnew) la, decode = fitresult if typeof(m) == LaplaceRegressor yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index c225bfd..6254a7d 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -7,6 +7,7 @@ using MLJ using MLJ:predict,fit! using LaplaceRedux +cv = CV(; nfolds=3) @testset "Regression" begin flux_model = Chain( @@ -36,41 +37,40 @@ end @testset "Classification" begin -# Define the model -flux_model = Chain( - Dense(4, 10, relu), - Dense(10, 3) -) + # Define the model + flux_model = Chain( + Dense(4, 10, relu), + Dense(10, 3) + ) -model = LaplaceClassifier(model=flux_model,epochs=50) + model = LaplaceClassifier(model=flux_model,epochs=50) -X, y = @load_iris -mach = machine(model, X, y) -MLJBase.fit!(mach,verbosity=1) -Xnew = (sepal_length = [6.4, 7.2, 7.4], - sepal_width = [2.8, 3.0, 2.8], - petal_length = [5.6, 5.8, 6.1], - petal_width = [2.1, 1.6, 1.9],) -yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions -predict_mode(mach, Xnew) # point predictions -pdf.(yhat, "virginica") # probabilities for the "verginica" class -MLJBase.fitted_params(mach) # fitted params -MLJBase.training_losses(mach) #training loss history -model.epochs= 100 #changing number of epochs -MLJBase.fit!(mach) #testing update function -model.epochs= 50 #changing number of epochs to a lower number -MLJBase.fit!(mach) #testing update function -model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps -MLJBase.fit!(mach) #testing update function (the laplace part) + X, y = @load_iris + mach = machine(model, X, y) + MLJBase.fit!(mach,verbosity=1) + Xnew = (sepal_length = [6.4, 7.2, 7.4], + sepal_width = [2.8, 3.0, 2.8], + petal_length = [5.6, 5.8, 6.1], + petal_width = [2.1, 1.6, 1.9],) + yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions + predict_mode(mach, Xnew) # point predictions + pdf.(yhat, "virginica") # probabilities for the "verginica" class + MLJBase.fitted_params(mach) # fitted params + MLJBase.training_losses(mach) #training loss history + model.epochs= 100 #changing number of epochs + MLJBase.fit!(mach) #testing update function + model.epochs= 50 #changing number of epochs to a lower number + MLJBase.fit!(mach) #testing update function + model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps + MLJBase.fit!(mach) #testing update function (the laplace part) -# Define a different model -flux_model_two = Chain( - Dense(4, 6, relu), - Dense(6, 3) -) + # Define a different model + flux_model_two = Chain( + Dense(4, 6, relu), + Dense(6, 3) + ) -model.model = flux_model_two + model.model = flux_model_two -MLJBase.fit!(mach) - + MLJBase.fit!(mach) end