diff --git a/src/Trees/AbstractTrees_BetaML_interface.jl b/src/Trees/AbstractTrees_BetaML_interface.jl
index b61389e5..d6672f65 100644
--- a/src/Trees/AbstractTrees_BetaML_interface.jl
+++ b/src/Trees/AbstractTrees_BetaML_interface.jl
@@ -11,7 +11,7 @@ For more information see [JuliaAI/DecisionTree.jl](https://github.com/JuliaAI/De
 The file `src/abstract_trees.jl` in that repo serves as a model implementation.
 """
 
-export InfoNode, InfoLeaf, wrap, DecisionNode, Leaf
+export InfoNode, InfoLeaf, wrapdn, DecisionNode, Leaf
 
 """
 These types are introduced so that additional information currently not present in
@@ -31,18 +31,18 @@ end
 AbstractTrees.nodevalue(l::InfoLeaf) = l.leaf # round(l.leaf,sigdigits=4)
 
 """
-    wrap(node:: DecisionNode, ...)
+    wrapdn(node::DecisionNode, ...)
 
 Called on the root node of a `DecisionTree` `dc` in order to add visualization
 information. In case of a `BetaML/DecisionTree` this is typically a list of feature names as follows:
 
-`wdc = wrap(dc, featurenames = ["Colour","Size"])`
+`wdc = wrapdn(dc, featurenames = ["Colour","Size"])`
 """
-wrap(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
-wrap(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
-wrap(mod::DecisionTreeEstimator, info::NamedTuple = NamedTuple()) = wrap(mod.par.tree, info)
-wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};featurenames=[]) = wrap(m,(featurenames=featurenames,))
+wrapdn(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
+wrapdn(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
+wrapdn(mod::DecisionTreeEstimator, info::NamedTuple = NamedTuple()) = wrapdn(mod.par.tree, info)
+wrapdn(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};featurenames=[]) = wrapdn(m,(featurenames=featurenames,))
@@ -50,8 +50,8 @@ wrap(m::Union{DecisionNode,Leaf,DecisionTreeEstimator};featurenames=[]) = wrap(m
 
 #### Implementation of the `AbstractTrees`-interface
 
 AbstractTrees.children(node::InfoNode) = (
-    wrap(node.node.trueBranch, node.info),
-    wrap(node.node.falseBranch, node.info)
+    wrapdn(node.node.trueBranch, node.info),
+    wrapdn(node.node.falseBranch, node.info)
 )
 AbstractTrees.children(node::InfoLeaf) = ()
diff --git a/src/Trees/DecisionTrees.jl b/src/Trees/DecisionTrees.jl
index 370d7501..427a512e 100644
--- a/src/Trees/DecisionTrees.jl
+++ b/src/Trees/DecisionTrees.jl
@@ -254,7 +254,7 @@ Dict{String, Any}("job_is_regression" => 1, "fitted_records" => 6, "max_reached_
 
 julia> using Plots, TreeRecipe, AbstractTrees
 julia> featurenames = ["Something", "Som else"];
-julia> wrapped_tree = wrap(dtree, featurenames = featurenames); # featurenames is otional
+julia> wrapped_tree = wrapdn(dtree, featurenames = featurenames); # featurenames is optional
 julia> print_tree(wrapped_tree)
 Som else >= 18.0?
 ├─ Som else >= 31.0?
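
Note: a minimal usage sketch of the renamed `wrapdn`, assembled only from the signatures and the doctest shown above (the toy data and the `max_depth` value are illustrative assumptions, not part of this patch):

    using BetaML, AbstractTrees
    X = [1.8 2.5; 0.5 20.5; 0.6 18; 0.7 22.8; 0.4 31; 1.7 3.7]  # toy 2-feature matrix
    y = 2 .* X[:,1] .- X[:,2] .+ 3                              # toy regression target
    dtree = DecisionTreeEstimator(max_depth=2)
    fit!(dtree, X, y)                                           # train the tree
    wdt = wrapdn(dtree, featurenames = ["Colour","Size"])       # attach feature names for display
    print_tree(wdt)                                             # text rendering via AbstractTrees

`wrapdn` accepts a raw `DecisionNode`/`Leaf` or a fitted `DecisionTreeEstimator`; in the latter case it unwraps `mod.par.tree` itself, so callers never need to reach into the model's internals.
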
diff --git a/test/Trees_tests.jl b/test/Trees_tests.jl
index 3dd35564..965a29a6 100644
--- a/test/Trees_tests.jl
+++ b/test/Trees_tests.jl
@@ -57,7 +57,7 @@ x = [1 0.1; 2 0.2; 3 0.3; 4 0.4; 5 0.5]
 y = ["a","a","b","b","b"]
 @test findbestgain_sortedvector(x,y,2,x[:,2];mCols=[],currentUncertainty=gini(y),splitting_criterion=gini,rng=copy(TESTRNG)) == 0.3
 
-wrappedNode = BetaML.wrap(myTree)
+wrappedNode = BetaML.wrapdn(myTree)
 print("Node printing: ")
 printnode(stdout,wrappedNode)
 println("")
@@ -83,12 +83,12 @@ X = [1.8 2.5; 0.5 20.5; 0.6 18; 0.7 22.8; 0.4 31; 1.7 3.7];
 y = 2 .* X[:,1] .- X[:,2] .+ 3;
 mod = DecisionTreeEstimator(max_depth=2)
 ŷ = fit!(mod,X,y);
-wmod = wrap(mod,featurenames=["dim1","dim2"])
+wmod = wrapdn(mod,featurenames=["dim1","dim2"])
 print_tree(wmod)
 
 y2 = ["a","b","b","c","b","a"]
 mod2 = DecisionTreeEstimator(max_depth=2)
 ŷ2 = fit!(mod2,X,y2);
-wmod2 = wrap(mod2,featurenames=["dim1","dim2"])
+wmod2 = wrapdn(mod2,featurenames=["dim1","dim2"])
 print_tree(wmod2)
 #print(myTree)
diff --git a/test/Trees_tests_additional.jl b/test/Trees_tests_additional.jl
index e6ab32e8..15e23e0a 100644
--- a/test/Trees_tests_additional.jl
+++ b/test/Trees_tests_additional.jl
@@ -58,7 +58,7 @@ tree = Tree()
 
     println("--> add information about feature names")
    featurenames = ["Color", "Size"]
-    wrapped_tree = wrap(model, featurenames = featurenames)
+    wrapped_tree = wrapdn(model, featurenames = featurenames)
 
    println("--> plot the tree using the `TreeRecipe`")
    plt = plot(wrapped_tree)   # this automatically calls the `TreeRecipe`
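
For the plotting path exercised in `Trees_tests_additional.jl`, a hedged sketch of the same flow (assuming the `Plots` and `TreeRecipe` packages named in the diffs above; `model` stands for any fitted tree, as in the previous sketch):

    using Plots, TreeRecipe
    wrapped_tree = wrapdn(model, featurenames = ["Color", "Size"])
    plt = plot(wrapped_tree)      # dispatches to the TreeRecipe plot recipe
    # savefig(plt, "tree.png")    # optional: persist the rendered tree

Because `wrapdn` only wraps the tree in `InfoNode`/`InfoLeaf` with display metadata, the same wrapped value feeds both `print_tree` (text) and `plot` (graphics) without refitting the model.
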