Hi, I am trying to implement an interface between LaplaceRedux and MLJ, but I am having trouble implementing the probabilistic classifier model. In particular, I have not fully understood how to use `UnivariateFinite` correctly.
I have imported the packages:

```julia
using Flux
using Random
using Tables
using LinearAlgebra
using LaplaceRedux
using MLJBase
import MLJModelInterface as MMI
using Distributions: Normal
```
I have also created the model and written a `fit` function:

```julia
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
    X = MLJBase.matrix(X) |> permutedims
    decode = y[1]
    y_plain = MLJBase.int(y) .- 1
    y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))
    data_loader = Flux.DataLoader((X, y_onehot), batchsize=m.batch_size)
    opt_state = Flux.setup(Adam(), m.flux_model)
    for epoch in 1:m.epochs
        Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot
            m.flux_loss(model(X), y_onehot)
        end
    end
    la = LaplaceRedux.Laplace(
        m.flux_model;
        likelihood=:classification,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )
    # fit the Laplace model:
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
    report = (status="success", message="Model fitted successfully")
    cache = nothing
    return ((la, decode), cache, report)
end
```
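One thing worth checking in the `fit` above: `UnivariateFinite(support, probs)` pairs `probs[i]` with `support[i]` by position. However, `unique(y_plain)` orders the one-hot rows by first appearance in `y`, and it omits classes that are absent from the training labels. `predict`, by contrast, labels the rows with `classes(decode)`, which follows the pool's level order.

Below is a minimal sketch of an encoding that keeps the two aligned. It assumes the network's k-th output is meant to be the k-th class in `classes(decode)`. The labels are made up for illustration.

```julia
using Flux     # onehotbatch, onecold
using MLJBase  # categorical, classes, int

# Made-up labels; their levels are sorted, so classes are bird, cat, dog.
y = categorical(["dog", "cat", "dog", "bird"])
n_classes = length(classes(y[1]))

# int(y) gives each label's position among classes(y[1]), so using
# 1:n_classes as the label set makes row k of the one-hot matrix
# correspond to classes(y[1])[k], whatever order the labels appear in.
codes = Int.(MLJBase.int(y))
y_onehot = Flux.onehotbatch(codes, 1:n_classes)

# For comparison, unique(codes) == [3, 2, 1] here, so an encoding built
# from unique(...) would put "dog" (not "bird") in row 1.
```

Every class in the pool gets a row this way, even if it never appears in the training labels. That keeps the network's output dimension equal to `length(classes(decode))`.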
and the `predict` function:

```julia
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
    la = fitresult
    Xnew = MLJBase.matrix(Xnew) |> permutedims
    predictions = LaplaceRedux.predict(
        la,
        Xnew;
        link_approx=m.link_approx,
        ret_distr=false,
    )
    return [
        MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction, pool=decode, augment=true)
        for prediction in predictions
    ]
end
```
But when I run `predict`, I get the warning

```
Warning: Ignoring value of `pool` as the specified support defines one already.
```

and the error points only to the last line, the one calling `UnivariateFinite`.
I think you can just drop `pool=decode`. `MLJBase.classes(decode)` is a categorical vector, so it already carries the pool. You only need to specify `pool` if the first argument of `UnivariateFinite` is a raw vector rather than a categorical vector.
(Levels of the pool that are not in the first argument still get assigned a probability, namely zero.)
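Below is a minimal sketch of the constructor forms discussed above, using only names that `MLJBase` re-exports. The labels and the probability matrix are made up.

It rests on one assumption that should be checked against the LaplaceRedux documentation: that `LaplaceRedux.predict(...; ret_distr=false)` returns a classes × observations matrix. If it does, `for prediction in predictions` visits that matrix entry by entry (Julia iterates arrays element-wise), not column by column, so the single matrix call at the end is a closer fit.

`augment=true` is also left out here. As I understand it, that option is for the case where one class's probability is omitted from `probs`, whereas here every class gets one.

```julia
using MLJBase  # re-exports categorical, classes, UnivariateFinite, pdf

y = categorical(["bird", "cat", "dog", "cat"])
decode = y[1]               # any element carries the full pool
support = classes(decode)   # categorical vector in level order: bird, cat, dog

# 1. Categorical support: the pool comes with it, so no `pool` keyword.
d1 = UnivariateFinite(support, [0.2, 0.5, 0.3])

# 2. Raw support: a pool is now needed to match labels to levels.
#    The pool level "dog" is not in the support, so it gets probability zero.
d2 = UnivariateFinite(["bird", "cat"], [0.4, 0.6], pool=y)

# 3. Many observations at once: one row per observation, one column per class.
#    `probs` is a made-up stand-in for the ASSUMED classes x observations
#    output of LaplaceRedux.predict, with rows ordered as classes(decode).
probs = [0.7 0.1 0.2;
         0.2 0.6 0.3;
         0.1 0.3 0.5]
yhat = UnivariateFinite(support, permutedims(probs))  # one distribution per column of probs

pdf(d1, "cat")        # probability assigned to "cat"
pdf(d2, "dog")        # zero: "dog" is in the pool but not in the support
pdf(yhat[1], "bird")  # first observation's probability for "bird"
```

In the `predict` method above, form 3 would replace the comprehension with a single call along the lines of `MLJBase.UnivariateFinite(MLJBase.classes(decode), permutedims(predictions))`, provided the assumed output layout holds.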