Skip to content

Commit

Permalink
Major updates before release
Browse files Browse the repository at this point in the history
  • Loading branch information
Moran79 committed Sep 16, 2024
1 parent e27d874 commit 203f670
Show file tree
Hide file tree
Showing 36 changed files with 806 additions and 1,876 deletions.
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: LDATree
Title: Classification Trees with Linear Discriminant Analysis at
Terminal Nodes
Version: 0.1.2
Version: 0.1.2.9001
Authors@R:
person("Siyu", "Wang", , "[email protected]", role = c("cre", "aut", "cph"),
comment = c(ORCID = "0009-0005-2098-7089"))
Expand All @@ -13,11 +13,12 @@ License: MIT + file LICENSE
URL: https://github.com/Moran79/LDATree, http://iamwangsiyu.com/LDATree/
BugReports: https://github.com/Moran79/LDATree/issues
Imports:
folda,
ggplot2,
lifecycle,
grDevices,
magrittr,
scales,
stats,
utils,
visNetwork
Encoding: UTF-8
LazyData: true
Expand Down
15 changes: 0 additions & 15 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,24 +1,9 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,SingleTreee)
S3method(plot,Treee)
S3method(predict,SingleTreee)
S3method(predict,Treee)
S3method(predict,ldaGSVD)
S3method(print,Treee)
S3method(print,TreeeNode)
S3method(print,ldaGSVD)
S3method(summary,TreeeNode)
export(Treee)
export(ldaGSVD)
importFrom(lifecycle,deprecated)
importFrom(magrittr,"%>%")
importFrom(stats,.getXlevels)
importFrom(stats,delete.response)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,model.matrix.lm)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,sd)
importFrom(stats,terms)
9 changes: 0 additions & 9 deletions R/LDATree-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@
"_PACKAGE"

## usethis namespace: start
#' @importFrom lifecycle deprecated
#' @importFrom magrittr %>%
#' @importFrom stats .getXlevels
#' @importFrom stats delete.response
#' @importFrom stats model.frame
#' @importFrom stats model.matrix
#' @importFrom stats model.matrix.lm
#' @importFrom stats predict
#' @importFrom stats quantile
#' @importFrom stats sd
#' @importFrom stats terms
## usethis namespace: end
NULL
232 changes: 117 additions & 115 deletions R/Treee.R
Original file line number Diff line number Diff line change
@@ -1,157 +1,159 @@
#' Classification trees with Linear Discriminant Analysis terminal nodes
#' Classification Trees with Uncorrelated Linear Discriminant Analysis Terminal
#' Nodes
#'
#' @description `r lifecycle::badge('experimental')` Fit an LDATree model.
#' This function fits a classification tree where each node has a Uncorrelated
#' Linear Discriminant Analysis (ULDA) model. It can also handle missing values
#' and perform downsampling. The resulting tree can be pruned either through
#' pre-pruning or post-pruning methods.
#'
#' @details
#' @param datX A data frame of predictor variables.
#' @param response A vector of response values corresponding to `datX`.
#' @param ldaType A character string specifying the type of LDA to use. Options
#' are `"forward"` for forward ULDA or `"all"` for full ULDA. Default is
#' `"forward"`.
#' @param nodeModel A character string specifying the type of model used in each
#' node. Options are `"ULDA"` for Uncorrelated LDA, or `"mode"` for predicting
#' based on the most frequent class. Default is `"ULDA"`.
#' @param pruneMethod A character string specifying the pruning method. `"pre"`
#' performs pre-pruning based on p-value thresholds, and `"post"` performs
#' cross-validation-based post-pruning. Default is `"pre"`.
#' @param numberOfPruning An integer specifying the number of folds for
#' cross-validation during post-pruning. Default is `10`.
#' @param maxTreeLevel An integer controlling the maximum depth of the tree.
#' Increasing this value allows for deeper trees with more nodes. Default is
#' `20`.
#' @param minNodeSize An integer controlling the minimum number of samples
#' required in a node. Setting a higher value may lead to earlier stopping and
#' smaller trees. If not specified, it defaults to one plus the number of
#' response classes.
#' @param pThreshold A numeric value used as a threshold for pre-pruning based
#' on p-values. Lower values result in more conservative trees. If not
#' specified, defaults to `0.01` for pre-pruning and `0.51` for post-pruning.
#' @param prior A numeric vector of prior probabilities for each class. If
#' `NULL`, the prior is automatically calculated from the data.
#' @param misClassCost A square matrix \eqn{C}, where each element \eqn{C_{ij}}
#' represents the cost of classifying an observation into class \eqn{i} given
#' that it truly belongs to class \eqn{j}. If `NULL`, a default matrix with
#' equal misclassification costs for all class pairs is used. Default is
#' `NULL`.
#' @param missingMethod A character string specifying how missing values should
#' be handled. Options include `'mean'`, `'median'`, `'meanFlag'`,
#' `'medianFlag'` for numerical variables, and `'mode'`, `'modeFlag'`,
#' `'newLevel'` for factor variables. `'Flag'` options indicate whether a
#' missing flag is added, while `'newLevel'` replaces missing values with a
#' new factor level.
#' @param kSample An integer specifying the number of samples to use for
#' downsampling during tree construction. Set to `-1` to disable downsampling.
#' @param verbose A logical value. If `TRUE`, progress messages and detailed
#' output are printed during tree construction and pruning. Default is
#' `FALSE`.
#'
#' Unlike other classification trees, LDATree integrates LDA throughout the
#' entire tree-growing process. Here is a breakdown of its distinctive features:
#' * The tree searches for the best binary split based on sample quantiles of the first linear discriminant score.
#' @returns An object of class `Treee` containing the fitted tree, which is a
#' list of nodes, each an object of class `TreeeNode`. Each `TreeeNode`
#' contains:
#' * `currentIndex`: The node index in the tree.
#' * `currentLevel`: The depth of the current node in the tree.
#' * `idxRow`, `idxCol`: Row and column indices indicating which part of the original data was used for this node.
#' * `currentLoss`: The training error for this node.
#' * `accuracy`: The training accuracy for this node.
#' * `stopInfo`: Information on why the node stopped growing.
#' * `proportions`: The observed frequency of each class in this node.
#' * `prior`: The (adjusted) class prior probabilities used for ULDA or mode prediction.
#' * `misClassCost`: The misclassification cost matrix used in this node.
#' * `parent`: The index of the parent node.
#' * `children`: A vector of indices of this node’s direct children.
#' * `splitFun`: The splitting function used for this node.
#' * `nodeModel`: Indicates the model fitted at the node (`'ULDA'` or `'mode'`).
#' * `nodePredict`: The fitted model at the node, either a ULDA object or the plurality class.
#' * `alpha`: The p-value from a two-sample t-test used to evaluate the strength of the split.
#' * `childrenTerminal`: A vector of indices representing the terminal nodes that are descendants of this node.
#' * `childrenTerminalLoss`: The total training error accumulated from all nodes listed in `childrenTerminal`.
#'
#' * An LDA/GSVD model is fitted for each terminal node (For more details, refer to [ldaGSVD()]).
#'
#' * Missing values can be imputed using the mean, median, or mode, with optional missing flags available.
#'
#' * By default, the tree employs a direct-stopping rule. However, cross-validation using the alpha-pruning from CART is also provided.
#'
#' @param formula an object of class [formula], which has the form `class ~ x1 +
#' x2 + ...`
#' @param data a data frame that contains both predictors and the response.
#' Missing values are allowed in predictors but not in the response.
#' @param missingMethod Missing value solutions for numerical variables and
#' factor variables. `'mean'`, `'median'`, `'meanFlag'`, `'medianFlag'` are
#' available for numerical variables. `'mode'`, `'modeFlag'`, `'newLevel'` are
#' available for factor variables. The word `'Flag'` in the methods indicates
#' whether a missing flag is added or not. The `'newLevel'` method means that
#' all missing values are replaced with a new level rather than imputing them
#' to another existing value.
#' @param maxTreeLevel controls the largest tree size possible for either a
#' direct-stopping tree or a CV-pruned tree. Adding one extra level (depth)
#' introduces an additional layer of nodes at the bottom of the current tree.
#' e.g., when the maximum level is 1 (or 2), the maximum tree size is 3 (or
#' 7).
#' @param minNodeSize controls the minimum node size. Think carefully before
#' changing this value. Setting a large number might result in early stopping
#' and reduced accuracy. By default, it's set to one plus the number of
#' classes in the response variable.
#' @param verbose a logical. If TRUE, the function provides additional
#' diagnostic messages or detailed output about its progress or internal
#' workings. Default is FALSE, where the function runs silently without
#' additional output.
#'
#' @returns An object of class `Treee` containing the following components:
#' * `formula`: the formula passed to the [Treee()]
#' * `treee`: a list of all the tree nodes, and each node is an object of class `TreeeNode`.
#' * `missingMethod`: the missingMethod passed to the [Treee()]
#'
#' An object of class `TreeeNode` containing the following components:
#' * `currentIndex`: the node index of the current node
#' * `currentLevel`: the level of the current node in the tree
#' * `idxRow`, `idxCol`: the row and column indices showing which portion of data is used in the current node
#' * `currentLoss`: ?
#' * `accuracy`: the training accuracy of the current node
#' * `stopFlag`: ?
#' * `proportions`: shows the observed frequency for each class
#' * `parent`: the node index of its parent
#' * `children`: the node indices of its direct children (not including its children's children)
#' * `misReference`: a data frame, serves as the reference for missing value imputation
#' * `splitFun`: ?
#' * `nodeModel`: one of `'mode'` or `'LDA'`. It shows the type of predictive model fitted in the current node
#' * `nodePredict`: the fitted predictive model in the current node. It is an object of class `ldaGSVD` if LDA is fitted. If `nodeModel = 'mode'`, then it is a vector of length one, showing the plurality class.
#' @export
#'
#' @references Wang, S. (2024). A New Forward Discriminant Analysis Framework
#' Based On Pillai's Trace and ULDA. \emph{arXiv preprint arXiv:2409.03136}.
#' Available at \url{https://arxiv.org/abs/2409.03136}.
#'
#' @examples
#' fit <- Treee(Species~., data = iris)
#' fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
#' # Use cross-validation to prune the tree
#' fitCV <- Treee(Species~., data = iris)
#' # prediction
#' predict(fit,iris)
#' # plot the overall tree
#' plot(fit)
#' # plot a certain node
#' plot(fit, iris, node = 1)
#' fitCV <- Treee(datX = iris[, -5], response = iris[, 5], pruneMethod = "post", verbose = FALSE)
#' head(predict(fit, iris)) # prediction
#' plot(fit) # plot the overall tree
#' plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a certain node
Treee <- function(datX,
response,
ldaType = c("step", "all"),
nodeModel = c("LDA", "mode"),
missingMethod = c("medianFlag", "newLevel"),
prior = NULL,
misClassCost = NULL,
pruneMethod = c("post", "pre", "pre-post"),
numberOfPruning = 10,
ldaType = c("forward", "all"),
nodeModel = c("ULDA", "mode"),
pruneMethod = c("pre", "post"),
numberOfPruning = 10L,
maxTreeLevel = 20L,
minNodeSize = NULL,
pThreshold = 0.1,
verbose = TRUE,
kSample = 1e7){
pThreshold = NULL,
prior = NULL,
misClassCost = NULL,
missingMethod = c("medianFlag", "newLevel"),
kSample = -1,
verbose = TRUE){ # Change verbose to FALSE before CRAN submission

# Standardize the Arguments -----------------------------------------------

response <- as.factor(response) # make it a factor
ldaType <- match.arg(ldaType, c("step", "all"))
nodeModel <- match.arg(nodeModel, c("LDA", "mode"))
missingMethod <- c(match.arg(missingMethod[1], c("mean", "median", "meanFlag", "medianFlag")),
match.arg(missingMethod[2], c("mode", "modeFlag", "newLevel")))
prior <- checkPriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response, internal = TRUE)
pruneMethod <- match.arg(pruneMethod, c("post", "pre", "pre-post"))
if(is.null(minNodeSize)) minNodeSize <- nlevels(response) + 1 # minNodeSize: If not specified, set to J+1

#> Does not support ordered factors
for(i in seq_along(datX)){
if("ordered" %in% class(datX[,i])) class(datX[,i]) <- "factor"
datX <- data.frame(datX) # change to data.frame, remove the potential tibble attribute
for(i in seq_along(datX)){ # remove ordered factors
if(inherits(datX[,i], c("ordered"))) class(datX[,i]) <- "factor"
}
# Remove NAs in the response
idxNonNA <- which(!is.na(response)); response <- droplevels(factor(response[idxNonNA], ordered = FALSE))
datX <- datX[idxNonNA, , drop = FALSE]

ldaType <- match.arg(ldaType, c("forward", "all"))
nodeModel <- match.arg(nodeModel, c("ULDA", "mode"))
pruneMethod <- match.arg(pruneMethod, c("pre", "post"))
if(is.null(minNodeSize)) minNodeSize <- nlevels(response) + 1 # minNodeSize: If not specified, set to J+1
if(is.null(pThreshold)) pThreshold <- ifelse(pruneMethod == "pre", 0.01, 0.51)
priorAndMisClassCost <- updatePriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response, insideNode = FALSE)
prior <- priorAndMisClassCost$prior; misClassCost <- priorAndMisClassCost$misClassCost

# Build Different Trees ---------------------------------------------------

treeeNow = new_SingleTreee(datX = datX,
response = response,
ldaType = ldaType,
nodeModel = nodeModel,
missingMethod = missingMethod,
prior = prior,
maxTreeLevel = maxTreeLevel,
minNodeSize = minNodeSize,
pThreshold = ifelse(pruneMethod == "pre", pThreshold, 0.51),
verbose = verbose,
kSample = kSample)

if(pruneMethod == "pre-post"){
#> Step Ahead and prune back, 0.51 as a looser bound
#> Discard due to unnecessary
treeeNow <- updateAlphaInTree(treeeNow)
treeeNow <- pruneTreee(treeeNow, pThreshold)
treeeNow <- dropNodes(treeeNow)
}
pThreshold = pThreshold,
prior = prior,
misClassCost = misClassCost,
missingMethod = missingMethod,
kSample = kSample,
verbose = verbose)

if(verbose) cat(paste('\nThe pre-pruned LDA tree is completed. It has', length(treeeNow), 'nodes.\n'))

finalTreee <- structure(list(treee = treeeNow,
missingMethod = missingMethod), class = "Treee")


# Pruning -----------------------------------------------------------------

if(pruneMethod == "post" & length(treeeNow) > 1){
pruningOutput <- prune(oldTreee = treeeNow,
numberOfPruning = numberOfPruning,
datX = datX,
response = response,
ldaType = ldaType,
nodeModel = nodeModel,
missingMethod = missingMethod,
prior = prior,
numberOfPruning = numberOfPruning,
maxTreeLevel = maxTreeLevel,
minNodeSize = minNodeSize,
pThreshold = 0.51,
verbose = verbose,
kSample = kSample)

# Add something to the finalTreee
finalTreee$treee <- pruningOutput$treeeNew
finalTreee$CV_Table <- pruningOutput$CV_Table
pThreshold = pThreshold,
prior = prior,
misClassCost = misClassCost,
missingMethod = missingMethod,
kSample = kSample,
verbose = verbose)

if(verbose) cat(paste('\nThe post-pruned tree is completed. It has', length(finalTreee$treee), 'nodes.\n'))
treeeNow <- pruningOutput$treeeNew
attr(treeeNow, "CV_Table") <- pruningOutput$CV_Table
if(verbose) cat(paste('\nThe post-pruned tree is completed. It has', length(treeeNow), 'nodes.\n'))
}

return(finalTreee)
return(treeeNow)
}
Loading

0 comments on commit 203f670

Please sign in to comment.