Skip to content

Commit

Permalink
returning one-hot encoded directly
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Oct 16, 2024
1 parent 7c4d744 commit f872d96
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 54 deletions.
36 changes: 14 additions & 22 deletions src/direct_mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,22 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
fit_prior_nsteps::Int = 100::(_ > 0)
end

Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier}
LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier}

# for fit:
MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :),nothing))

MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))

function MMI.reformat(::LaplaceClassifier, X, y)

X = MLJBase.matrix(X) |> permutedims
y = categorical(y)
unique_labels = y.pool.levels
y = Flux.onehotbatch(y, unique_labels) # One-hot encoding


y= reshape(y, 1, :)
# Convert labels to integer format starting from 0 for one-hot encoding
y_plain = MLJBase.int(y[1, :]) .- 1
# One-hot encoding of labels
unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
return X,(y_onehot, y[1])
return X, y
end

#MMI.selectrows(::LaplaceClassifier, I, Xmatrix, y) = (view(Xmatrix, :, I), (view(y[1],I),y[2]))
#MMI.selectrows(::LaplaceRegressor, I, Xmatrix, y) = (view(Xmatrix, :,I), (view(y[1],I),y[2]) )
#for predict:
MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,)
MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,)

@doc """
MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
Expand All @@ -85,7 +77,7 @@ Fit a Laplace model using the provided features and target values.
- `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history.
- `report`: A Namedtuple containing the loss history of the fitting process.
"""
function MMI.fit(m::Laplace_Models, verbosity, X, y)
function MMI.fit(m::LaplaceModels, verbosity, X, y)
decode = y[2]

y= y[1]
Expand Down Expand Up @@ -172,7 +164,7 @@ Update the Laplace model using the provided new data points.
- `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history.
- `report`: A Namedtuple containing the complete loss history of the fitting process.
"""
function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y)
function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
decode = y[2]
y_up=y[1]

Expand Down Expand Up @@ -302,7 +294,7 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
end

@doc """
function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)
If both `m1` and `m2` are of `MLJType`, return `true` if the
following conditions all hold, and `false` otherwise:
Expand All @@ -329,7 +321,7 @@ meaning; see [`deep_properties`](@ref)) for details.
If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
"""
function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)
typeof(m1) === typeof(m2) || return false
names = propertynames(m1)
propertynames(m2) === names || return false
Expand Down Expand Up @@ -420,7 +412,7 @@ end
- `loss`: The loss value of the posterior distribution.
"""
function MMI.fitted_params(model::Laplace_Models, fitresult)
function MMI.fitted_params(model::LaplaceModels, fitresult)
la, decode = fitresult
posterior = la.posterior
return (
Expand All @@ -447,7 +439,7 @@ Retrieve the training loss history from the given `report`.
# Returns
- A collection representing the loss history from the training report.
"""
function MMI.training_losses(model::Laplace_Models, report)
function MMI.training_losses(model::LaplaceModels, report)
return report.loss_history
end

Expand All @@ -467,7 +459,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
for LaplaceClassifier:
- `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
"""
function MMI.predict(m::Laplace_Models, fitresult, Xnew)
function MMI.predict(m::LaplaceModels, fitresult, Xnew)
la, decode = fitresult
if typeof(m) == LaplaceRegressor
yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
Expand Down
64 changes: 32 additions & 32 deletions test/direct_mlj_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using MLJ
using MLJ:predict,fit!
using LaplaceRedux

cv = CV(; nfolds=3)

@testset "Regression" begin
flux_model = Chain(
Expand Down Expand Up @@ -36,41 +37,40 @@ end


@testset "Classification" begin
# Define the model
flux_model = Chain(
Dense(4, 10, relu),
Dense(10, 3)
)
# Define the model
flux_model = Chain(
Dense(4, 10, relu),
Dense(10, 3)
)

model = LaplaceClassifier(model=flux_model,epochs=50)
model = LaplaceClassifier(model=flux_model,epochs=50)

X, y = @load_iris
mach = machine(model, X, y)
MLJBase.fit!(mach,verbosity=1)
Xnew = (sepal_length = [6.4, 7.2, 7.4],
sepal_width = [2.8, 3.0, 2.8],
petal_length = [5.6, 5.8, 6.1],
petal_width = [2.1, 1.6, 1.9],)
yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
pdf.(yhat, "virginica") # probabilities for the "verginica" class
MLJBase.fitted_params(mach) # fitted params
MLJBase.training_losses(mach) #training loss history
model.epochs= 100 #changing number of epochs
MLJBase.fit!(mach) #testing update function
model.epochs= 50 #changing number of epochs to a lower number
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
X, y = @load_iris
mach = machine(model, X, y)
MLJBase.fit!(mach,verbosity=1)
Xnew = (sepal_length = [6.4, 7.2, 7.4],
sepal_width = [2.8, 3.0, 2.8],
petal_length = [5.6, 5.8, 6.1],
petal_width = [2.1, 1.6, 1.9],)
yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
pdf.(yhat, "virginica") # probabilities for the "verginica" class
MLJBase.fitted_params(mach) # fitted params
MLJBase.training_losses(mach) #training loss history
model.epochs= 100 #changing number of epochs
MLJBase.fit!(mach) #testing update function
model.epochs= 50 #changing number of epochs to a lower number
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)

# Define a different model
flux_model_two = Chain(
Dense(4, 6, relu),
Dense(6, 3)
)
# Define a different model
flux_model_two = Chain(
Dense(4, 6, relu),
Dense(6, 3)
)

model.model = flux_model_two
model.model = flux_model_two

MLJBase.fit!(mach)

MLJBase.fit!(mach)
end

0 comments on commit f872d96

Please sign in to comment.