Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pasq-cat committed Sep 12, 2024
1 parent e8e96d1 commit 255cf19
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/direct_mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux
flux_model::Flux.Chain= nothing
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
Expand Down Expand Up @@ -105,7 +105,7 @@ The model also has the following parameters:
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic

model::Flux
#model::Flux
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([])
hessian_structure::Union{HessianStructure,Symbol,String} =
Expand Down Expand Up @@ -172,18 +172,18 @@ MLJBase.metadata_model(
LaplaceClassifier;
input_scitype=Union{
AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
load_path="LaplaceRedux.LaplaceClassification",
load_path="LaplaceRedux.LaplaceClassifier",
)
# metadata for each model,
MLJBase.metadata_model(
LaplaceRegressor;
input_scitype=Union{
AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{MLJBase.Continuous},
load_path="LaplaceRedux.MLJFlux.LaplaceRegression",
load_path="LaplaceRedux.LaplaceRegressor",
)

0 comments on commit 255cf19

Please sign in to comment.