Skip to content

Commit

Permalink
Merge pull request #182 from JuliaAI/flat-params-fix
Browse files Browse the repository at this point in the history
Fix a corner case of flat_params
  • Loading branch information
ablaom authored Oct 19, 2023
2 parents ffb47b1 + da7bb36 commit ddaa91f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.9.2"
version = "1.9.3"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
49 changes: 33 additions & 16 deletions src/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ values, which themselves might be transparent.
Most objects of type `MLJType` are transparent.
```julia
julia> params(EnsembleModel(atom=ConstantClassifier()))
(atom = (target_type = Bool,),
julia> params(EnsembleModel(model=ConstantClassifier()))
(model = (target_type = Bool,),
weights = Float64[],
bagging_fraction = 0.8,
rng_seed = 0,
Expand All @@ -36,25 +36,42 @@ isnotaleaf(m::Model) = length(propertynames(m)) > 0
"""
flat_params(m::Model)
Recursively convert any object subtyping `Model` into a named tuple, keyed on
the property names of `m`. The named tuple is possibly nested because
`flat_params` is recursively applied to the property values, which themselves
might subtype `Model`.
Deconstruct any `Model` instance `model` as a flat named tuple, keyed on property
names. Properties of nested model instances are recursively exposed,.as shown in the
example below. For most `Model` objects, properties are synonymous with fields, but this
is not a hard requirement.
For most `Model` objects, properties are synonymous with fields, but this is
not a hard requirement.
```julia
using MLJModels
using EnsembleModels
tree = (@load DecisionTreeClassifier pkg=DecisionTree)
julia> flat_params(EnsembleModel(model=tree))
(model__max_depth = -1,
model__min_samples_leaf = 1,
model__min_samples_split = 2,
model__min_purity_increase = 0.0,
model__n_subfeatures = 0,
model__post_prune = false,
model__merge_purity_threshold = 1.0,
model__display_depth = 5,
model__feature_importance = :impurity,
model__rng = Random._GLOBAL_RNG(),
atomic_weights = Float64[],
bagging_fraction = 0.8,
rng = Random._GLOBAL_RNG(),
n = 100,
acceleration = CPU1{Nothing}(nothing),
out_of_bag_measure = Any[],)
```
julia> flat_params(EnsembleModel(atom=ConstantClassifier()))
(atom = (target_type = Bool,),
weights = Float64[],
bagging_fraction = 0.8,
rng_seed = 0,
n = 100,
parallel = true,)
"""
flat_params(m; prefix="") = flat_params(m, Val(isnotaleaf(m)); prefix=prefix)
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
function flat_params(m, ::Val{false}; prefix="")
prefix == "" && return NamedTuple()
NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
end
function flat_params(m, ::Val{true}; prefix="")
fields = propertynames(m)
prefix = prefix == "" ? "" : prefix * "__"
Expand Down
4 changes: 4 additions & 0 deletions test/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ end

struct Missy <: Model end

struct EmptyModel <: Model end

@testset "flat_params method" begin

m = ParentModel(1, "parent", ChildModel(2, "child1"),
Expand All @@ -61,5 +63,7 @@ struct Missy <: Model end
second_child__r = 3,
second_child__s = Missy()
)

@test MLJModelInterface.flat_params(EmptyModel()) == NamedTuple()
end
true

0 comments on commit ddaa91f

Please sign in to comment.