From 9217f666e3dd75e967b2ba22c7113de5e9ad7b1c Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 17 Oct 2023 09:40:12 +1100 Subject: [PATCH 1/2] fix a corner case in flat_params --- src/parameter_inspection.jl | 49 ++++++++++++++++++++++++------------ test/parameter_inspection.jl | 4 +++ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/parameter_inspection.jl b/src/parameter_inspection.jl index c1e74b8..bc5d2ed 100644 --- a/src/parameter_inspection.jl +++ b/src/parameter_inspection.jl @@ -13,8 +13,8 @@ values, which themselves might be transparent. Most objects of type `MLJType` are transparent. ```julia -julia> params(EnsembleModel(atom=ConstantClassifier())) -(atom = (target_type = Bool,), +julia> params(EnsembleModel(model=ConstantClassifier())) +(model = (target_type = Bool,), weights = Float64[], bagging_fraction = 0.8, rng_seed = 0, @@ -36,25 +36,42 @@ isnotaleaf(m::Model) = length(propertynames(m)) > 0 """ flat_params(m::Model) -Recursively convert any object subtyping `Model` into a named tuple, keyed on -the property names of `m`. The named tuple is possibly nested because -`flat_params` is recursively applied to the property values, which themselves -might subtype `Model`. +Deconstruct any `Model` instance `model` as a flat named tuple, keyed on property +names. Properties of nested model instances are recursively exposed,.as shown in the +example below. For most `Model` objects, properties are synonymous with fields, but this +is not a hard requirement. -For most `Model` objects, properties are synonymous with fields, but this is -not a hard requirement. +```julia +using MLJModels +using EnsembleModels +tree = (@load DecisionTreeClassifier pkg=DecisionTree) + +julia> flat_params(EnsembleModel(model=tree)) +(model__max_depth = -1, + model__min_samples_leaf = 1, + model__min_samples_split = 2, + model__min_purity_increase = 0.0, + model__n_subfeatures = 0, + model__post_prune = false, + model__merge_purity_threshold = 1.0, + model__display_depth = 5, + model__feature_importance = :impurity, + model__rng = Random._GLOBAL_RNG(), + atomic_weights = Float64[], + bagging_fraction = 0.8, + rng = Random._GLOBAL_RNG(), + n = 100, + acceleration = CPU1{Nothing}(nothing), + out_of_bag_measure = Any[],) +``` - julia> flat_params(EnsembleModel(atom=ConstantClassifier())) - (atom = (target_type = Bool,), - weights = Float64[], - bagging_fraction = 0.8, - rng_seed = 0, - n = 100, - parallel = true,) """ flat_params(m; prefix="") = flat_params(m, Val(isnotaleaf(m)); prefix=prefix) -flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,)) +function flat_params(m, ::Val{false}; prefix="") + prefix == "" && return NamedTuple() + NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,)) +end function flat_params(m, ::Val{true}; prefix="") fields = propertynames(m) prefix = prefix == "" ? "" : prefix * "__" diff --git a/test/parameter_inspection.jl b/test/parameter_inspection.jl index ab1c2b3..a4e09bd 100644 --- a/test/parameter_inspection.jl +++ b/test/parameter_inspection.jl @@ -48,6 +48,8 @@ end struct Missy <: Model end +struct EmptyModel <: Model end + @testset "flat_params method" begin m = ParentModel(1, "parent", ChildModel(2, "child1"), @@ -61,5 +63,7 @@ struct Missy <: Model end second_child__r = 3, second_child__s = Missy() ) + + @test MLJModelInterface.flat_params(EmptyModel()) == NamedTuple() end true From da7bb369b74b848248927e19f671aaaaeb055d0f Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 17 Oct 2023 09:41:12 +1100 Subject: [PATCH 2/2] bump 1.9.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 868cd03..eb2edbb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJModelInterface" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" authors = ["Thibaut Lienart and Anthony Blaom"] -version = "1.9.2" +version = "1.9.3" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"