Skip to content

Commit

Permalink
Renamed wrap to wrapdn ("decision node") as new "wrap" function in Base
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvaticus committed Jan 11, 2024
1 parent a11af2e commit fd8ec00
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
18 changes: 9 additions & 9 deletions src/Trees/AbstractTrees_BetaML_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ For more information see [JuliaAI/DecisionTree.jl](https://github.com/JuliaAI/De
The file `src/abstract_trees.jl` in that repo serves as a model implementation.
"""

export InfoNode, InfoLeaf, wrap, DecisionNode, Leaf
export InfoNode, InfoLeaf, wrapdn, DecisionNode, Leaf

"""
These types are introduced so that additional information currently not present in
Expand All @@ -31,27 +31,27 @@ end
AbstractTrees.nodevalue(l::InfoLeaf) = l.leaf # round(l.leaf,sigdigits=4)

"""
wrap(node:: DecisionNode, ...)
wrapdn(node:: DecisionNode, ...)
Called on the root node of a `DecisionTree` `dc` in order to add visualization information.
In case of a `BetaML/DecisionTree` this is typically a list of feature names as follows:
`wdc = wrap(dc, featurenames = ["Colour","Size"])`
`wdc = wrapdn(dc, featurenames = ["Colour","Size"])`
"""

wrap(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
wrap(mod::DecisionTreeEstimator, info::NamedTuple = NamedTuple()) = wrap(mod.par.tree, info)
wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};featurenames=[]) = wrap(m,(featurenames=featurenames,))
wrapdn(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrapdn(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
wrapdn(mod::DecisionTreeEstimator, info::NamedTuple = NamedTuple()) = wrapdn(mod.par.tree, info)
wrapdn(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};featurenames=[]) = wrapdn(m,(featurenames=featurenames,))




#### Implementation of the `AbstractTrees`-interface

AbstractTrees.children(node::InfoNode) = (
wrap(node.node.trueBranch, node.info),
wrap(node.node.falseBranch, node.info)
wrapdn(node.node.trueBranch, node.info),
wrapdn(node.node.falseBranch, node.info)
)
AbstractTrees.children(node::InfoLeaf) = ()

Expand Down
2 changes: 1 addition & 1 deletion src/Trees/DecisionTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ Dict{String, Any}("job_is_regression" => 1, "fitted_records" => 6, "max_reached_
julia> using Plots, TreeRecipe, AbstractTrees
julia> featurenames = ["Something", "Som else"];
julia> wrapped_tree = wrap(dtree, featurenames = featurenames); # featurenames is optional
julia> wrapped_tree = wrapdn(dtree, featurenames = featurenames); # featurenames is optional
julia> print_tree(wrapped_tree)
Som else >= 18.0?
├─ Som else >= 31.0?
Expand Down
6 changes: 3 additions & 3 deletions test/Trees_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ x = [1 0.1; 2 0.2; 3 0.3; 4 0.4; 5 0.5]
y = ["a","a","b","b","b"]
@test findbestgain_sortedvector(x,y,2,x[:,2];mCols=[],currentUncertainty=gini(y),splitting_criterion=gini,rng=copy(TESTRNG)) == 0.3

wrappedNode = BetaML.wrap(myTree)
wrappedNode = BetaML.wrapdn(myTree)
print("Node printing: ")
printnode(stdout,wrappedNode)
println("")
Expand All @@ -83,12 +83,12 @@ X = [1.8 2.5; 0.5 20.5; 0.6 18; 0.7 22.8; 0.4 31; 1.7 3.7];
y = 2 .* X[:,1] .- X[:,2] .+ 3;
mod = DecisionTreeEstimator(max_depth=2)
= fit!(mod,X,y);
wmod = wrap(mod,featurenames=["dim1","dim2"])
wmod = wrapdn(mod,featurenames=["dim1","dim2"])
print_tree(wmod)
y2 = ["a","b","b","c","b","a"]
mod2 = DecisionTreeEstimator(max_depth=2)
ŷ2 = fit!(mod2,X,y2);
wmod2 = wrap(mod2,featurenames=["dim1","dim2"])
wmod2 = wrapdn(mod2,featurenames=["dim1","dim2"])
print_tree(wmod2)

#print(myTree)
Expand Down
2 changes: 1 addition & 1 deletion test/Trees_tests_additional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ tree = Tree()

println("--> add information about feature names")
featurenames = ["Color", "Size"]
wrapped_tree = wrap(model, featurenames = featurenames)
wrapped_tree = wrapdn(model, featurenames = featurenames)

println("--> plot the tree using the `TreeRecipe`")
plt = plot(wrapped_tree) # this calls automatically the `TreeRecipe`
Expand Down

0 comments on commit fd8ec00

Please sign in to comment.