-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
36 changed files
with
806 additions
and
1,876 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
Package: LDATree | ||
Title: Classification Trees with Linear Discriminant Analysis at | ||
Terminal Nodes | ||
Version: 0.1.2 | ||
Version: 0.1.2.9001 | ||
Authors@R: | ||
person("Siyu", "Wang", , "[email protected]", role = c("cre", "aut", "cph"), | ||
comment = c(ORCID = "0009-0005-2098-7089")) | ||
|
@@ -13,11 +13,12 @@ License: MIT + file LICENSE | |
URL: https://github.com/Moran79/LDATree, http://iamwangsiyu.com/LDATree/ | ||
BugReports: https://github.com/Moran79/LDATree/issues | ||
Imports: | ||
folda, | ||
ggplot2, | ||
lifecycle, | ||
grDevices, | ||
magrittr, | ||
scales, | ||
stats, | ||
utils, | ||
visNetwork | ||
Encoding: UTF-8 | ||
LazyData: true | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,9 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
S3method(plot,SingleTreee) | ||
S3method(plot,Treee) | ||
S3method(predict,SingleTreee) | ||
S3method(predict,Treee) | ||
S3method(predict,ldaGSVD) | ||
S3method(print,Treee) | ||
S3method(print,TreeeNode) | ||
S3method(print,ldaGSVD) | ||
S3method(summary,TreeeNode) | ||
export(Treee) | ||
export(ldaGSVD) | ||
importFrom(lifecycle,deprecated) | ||
importFrom(magrittr,"%>%") | ||
importFrom(stats,.getXlevels) | ||
importFrom(stats,delete.response) | ||
importFrom(stats,model.frame) | ||
importFrom(stats,model.matrix) | ||
importFrom(stats,model.matrix.lm) | ||
importFrom(stats,predict) | ||
importFrom(stats,quantile) | ||
importFrom(stats,sd) | ||
importFrom(stats,terms) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,157 +1,159 @@ | ||
#' Classification trees with Linear Discriminant Analysis terminal nodes | ||
#' Classification Trees with Uncorrelated Linear Discriminant Analysis Terminal | ||
#' Nodes | ||
#' | ||
#' @description `r lifecycle::badge('experimental')` Fit an LDATree model. | ||
#' This function fits a classification tree where each node has an Uncorrelated | ||
#' Linear Discriminant Analysis (ULDA) model. It can also handle missing values | ||
#' and perform downsampling. The resulting tree can be pruned either through | ||
#' pre-pruning or post-pruning methods. | ||
#' | ||
#' @details | ||
#' @param datX A data frame of predictor variables. | ||
#' @param response A vector of response values corresponding to `datX`. | ||
#' @param ldaType A character string specifying the type of LDA to use. Options | ||
#' are `"forward"` for forward ULDA or `"all"` for full ULDA. Default is | ||
#' `"forward"`. | ||
#' @param nodeModel A character string specifying the type of model used in each | ||
#' node. Options are `"ULDA"` for Uncorrelated LDA, or `"mode"` for predicting | ||
#' based on the most frequent class. Default is `"ULDA"`. | ||
#' @param pruneMethod A character string specifying the pruning method. `"pre"` | ||
#' performs pre-pruning based on p-value thresholds, and `"post"` performs | ||
#' cross-validation-based post-pruning. Default is `"pre"`. | ||
#' @param numberOfPruning An integer specifying the number of folds for | ||
#' cross-validation during post-pruning. Default is `10`. | ||
#' @param maxTreeLevel An integer controlling the maximum depth of the tree. | ||
#' Increasing this value allows for deeper trees with more nodes. Default is | ||
#' `20`. | ||
#' @param minNodeSize An integer controlling the minimum number of samples | ||
#' required in a node. Setting a higher value may lead to earlier stopping and | ||
#' smaller trees. If not specified, it defaults to one plus the number of | ||
#' response classes. | ||
#' @param pThreshold A numeric value used as a threshold for pre-pruning based | ||
#' on p-values. Lower values result in more conservative trees. If not | ||
#' specified, defaults to `0.01` for pre-pruning and `0.51` for post-pruning. | ||
#' @param prior A numeric vector of prior probabilities for each class. If | ||
#' `NULL`, the prior is automatically calculated from the data. | ||
#' @param misClassCost A square matrix \eqn{C}, where each element \eqn{C_{ij}} | ||
#' represents the cost of classifying an observation into class \eqn{i} given | ||
#' that it truly belongs to class \eqn{j}. If `NULL`, a default matrix with | ||
#' equal misclassification costs for all class pairs is used. Default is | ||
#' `NULL`. | ||
#' @param missingMethod A character string specifying how missing values should | ||
#' be handled. Options include `'mean'`, `'median'`, `'meanFlag'`, | ||
#' `'medianFlag'` for numerical variables, and `'mode'`, `'modeFlag'`, | ||
#' `'newLevel'` for factor variables. Methods ending in `'Flag'` also add a | ||
#' missing-value flag, while `'newLevel'` replaces missing values with a | ||
#' new factor level. | ||
#' @param kSample An integer specifying the number of samples to use for | ||
#' downsampling during tree construction. Set to `-1` to disable downsampling. | ||
#' @param verbose A logical value. If `TRUE`, progress messages and detailed | ||
#' output are printed during tree construction and pruning. Default is | ||
#' `FALSE`. | ||
#' | ||
#' Unlike other classification trees, LDATree integrates LDA throughout the | ||
#' entire tree-growing process. Here is a breakdown of its distinctive features: | ||
#' * The tree searches for the best binary split based on sample quantiles of the first linear discriminant score. | ||
#' @returns An object of class `Treee` containing the fitted tree, which is a | ||
#' list of nodes, each an object of class `TreeeNode`. Each `TreeeNode` | ||
#' contains: | ||
#' * `currentIndex`: The node index in the tree. | ||
#' * `currentLevel`: The depth of the current node in the tree. | ||
#' * `idxRow`, `idxCol`: Row and column indices indicating which part of the original data was used for this node. | ||
#' * `currentLoss`: The training error for this node. | ||
#' * `accuracy`: The training accuracy for this node. | ||
#' * `stopInfo`: Information on why the node stopped growing. | ||
#' * `proportions`: The observed frequency of each class in this node. | ||
#' * `prior`: The (adjusted) class prior probabilities used for ULDA or mode prediction. | ||
#' * `misClassCost`: The misclassification cost matrix used in this node. | ||
#' * `parent`: The index of the parent node. | ||
#' * `children`: A vector of indices of this node’s direct children. | ||
#' * `splitFun`: The splitting function used for this node. | ||
#' * `nodeModel`: Indicates the model fitted at the node (`'ULDA'` or `'mode'`). | ||
#' * `nodePredict`: The fitted model at the node, either a ULDA object or the plurality class. | ||
#' * `alpha`: The p-value from a two-sample t-test used to evaluate the strength of the split. | ||
#' * `childrenTerminal`: A vector of indices representing the terminal nodes that are descendants of this node. | ||
#' * `childrenTerminalLoss`: The total training error accumulated from all nodes listed in `childrenTerminal`. | ||
#' | ||
#' * An LDA/GSVD model is fitted for each terminal node (For more details, refer to [ldaGSVD()]). | ||
#' | ||
#' * Missing values can be imputed using the mean, median, or mode, with optional missing flags available. | ||
#' | ||
#' * By default, the tree employs a direct-stopping rule. However, cross-validation using the alpha-pruning from CART is also provided. | ||
#' | ||
#' @param formula an object of class [formula], which has the form `class ~ x1 + | ||
#' x2 + ...` | ||
#' @param data a data frame that contains both predictors and the response. | ||
#' Missing values are allowed in predictors but not in the response. | ||
#' @param missingMethod Missing value solutions for numerical variables and | ||
#' factor variables. `'mean'`, `'median'`, `'meanFlag'`, `'medianFlag'` are | ||
#' available for numerical variables. `'mode'`, `'modeFlag'`, `'newLevel'` are | ||
#' available for factor variables. The word `'Flag'` in the methods indicates | ||
#' whether a missing flag is added or not. The `'newLevel'` method means that | ||
#' all missing values are replaced with a new level rather than imputing them | ||
#' to another existing value. | ||
#' @param maxTreeLevel controls the largest tree size possible for either a | ||
#' direct-stopping tree or a CV-pruned tree. Adding one extra level (depth) | ||
#' introduces an additional layer of nodes at the bottom of the current tree. | ||
#' e.g., when the maximum level is 1 (or 2), the maximum tree size is 3 (or | ||
#' 7). | ||
#' @param minNodeSize controls the minimum node size. Think carefully before | ||
#' changing this value. Setting a large number might result in early stopping | ||
#' and reduced accuracy. By default, it's set to one plus the number of | ||
#' classes in the response variable. | ||
#' @param verbose a logical. If TRUE, the function provides additional | ||
#' diagnostic messages or detailed output about its progress or internal | ||
#' workings. Default is FALSE, where the function runs silently without | ||
#' additional output. | ||
#' | ||
#' @returns An object of class `Treee` containing the following components: | ||
#' * `formula`: the formula passed to the [Treee()] | ||
#' * `treee`: a list of all the tree nodes, and each node is an object of class `TreeeNode`. | ||
#' * `missingMethod`: the missingMethod passed to the [Treee()] | ||
#' | ||
#' An object of class `TreeeNode` containing the following components: | ||
#' * `currentIndex`: the node index of the current node | ||
#' * `currentLevel`: the level of the current node in the tree | ||
#' * `idxRow`, `idxCol`: the row and column indices showing which portion of data is used in the current node | ||
#' * `currentLoss`: ? | ||
#' * `accuracy`: the training accuracy of the current node | ||
#' * `stopFlag`: ? | ||
#' * `proportions`: shows the observed frequency for each class | ||
#' * `parent`: the node index of its parent | ||
#' * `children`: the node indices of its direct children (not including its children's children) | ||
#' * `misReference`: a data frame, serves as the reference for missing value imputation | ||
#' * `splitFun`: ? | ||
#' * `nodeModel`: one of `'mode'` or `'LDA'`. It shows the type of predictive model fitted in the current node | ||
#' * `nodePredict`: the fitted predictive model in the current node. It is an object of class `ldaGSVD` if LDA is fitted. If `nodeModel = 'mode'`, then it is a vector of length one, showing the plurality class. | ||
#' @export | ||
#' | ||
#' @references Wang, S. (2024). A New Forward Discriminant Analysis Framework | ||
#' Based On Pillai's Trace and ULDA. \emph{arXiv preprint arXiv:2409.03136}. | ||
#' Available at \url{https://arxiv.org/abs/2409.03136}. | ||
#' | ||
#' @examples | ||
#' fit <- Treee(Species~., data = iris) | ||
#' fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE) | ||
#' # Use cross-validation to prune the tree | ||
#' fitCV <- Treee(Species~., data = iris) | ||
#' # prediction | ||
#' predict(fit,iris) | ||
#' # plot the overall tree | ||
#' plot(fit) | ||
#' # plot a certain node | ||
#' plot(fit, iris, node = 1) | ||
#' fitCV <- Treee(datX = iris[, -5], response = iris[, 5], pruneMethod = "post", verbose = FALSE) | ||
#' head(predict(fit, iris)) # prediction | ||
#' plot(fit) # plot the overall tree | ||
#' plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a certain node | ||
Treee <- function(datX, | ||
response, | ||
ldaType = c("step", "all"), | ||
nodeModel = c("LDA", "mode"), | ||
missingMethod = c("medianFlag", "newLevel"), | ||
prior = NULL, | ||
misClassCost = NULL, | ||
pruneMethod = c("post", "pre", "pre-post"), | ||
numberOfPruning = 10, | ||
ldaType = c("forward", "all"), | ||
nodeModel = c("ULDA", "mode"), | ||
pruneMethod = c("pre", "post"), | ||
numberOfPruning = 10L, | ||
maxTreeLevel = 20L, | ||
minNodeSize = NULL, | ||
pThreshold = 0.1, | ||
verbose = TRUE, | ||
kSample = 1e7){ | ||
pThreshold = NULL, | ||
prior = NULL, | ||
misClassCost = NULL, | ||
missingMethod = c("medianFlag", "newLevel"), | ||
kSample = -1, | ||
verbose = TRUE){ # Change verbose to FALSE before CRAN submission | ||
|
||
# Standardize the Arguments ----------------------------------------------- | ||
|
||
response <- as.factor(response) # make it a factor | ||
ldaType <- match.arg(ldaType, c("step", "all")) | ||
nodeModel <- match.arg(nodeModel, c("LDA", "mode")) | ||
missingMethod <- c(match.arg(missingMethod[1], c("mean", "median", "meanFlag", "medianFlag")), | ||
match.arg(missingMethod[2], c("mode", "modeFlag", "newLevel"))) | ||
prior <- checkPriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response, internal = TRUE) | ||
pruneMethod <- match.arg(pruneMethod, c("post", "pre", "pre-post")) | ||
if(is.null(minNodeSize)) minNodeSize <- nlevels(response) + 1 # minNodeSize: If not specified, set to J+1 | ||
|
||
#> Does not support ordered factors | ||
for(i in seq_along(datX)){ | ||
if("ordered" %in% class(datX[,i])) class(datX[,i]) <- "factor" | ||
datX <- data.frame(datX) # change to data.frame, remove the potential tibble attribute | ||
for(i in seq_along(datX)){ # remove ordered factors | ||
if(inherits(datX[,i], c("ordered"))) class(datX[,i]) <- "factor" | ||
} | ||
# Remove NAs in the response | ||
idxNonNA <- which(!is.na(response)); response <- droplevels(factor(response[idxNonNA], ordered = FALSE)) | ||
datX <- datX[idxNonNA, , drop = FALSE] | ||
|
||
ldaType <- match.arg(ldaType, c("forward", "all")) | ||
nodeModel <- match.arg(nodeModel, c("ULDA", "mode")) | ||
pruneMethod <- match.arg(pruneMethod, c("pre", "post")) | ||
if(is.null(minNodeSize)) minNodeSize <- nlevels(response) + 1 # minNodeSize: If not specified, set to J+1 | ||
if(is.null(pThreshold)) pThreshold <- ifelse(pruneMethod == "pre", 0.01, 0.51) | ||
priorAndMisClassCost <- updatePriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response, insideNode = FALSE) | ||
prior <- priorAndMisClassCost$prior; misClassCost <- priorAndMisClassCost$misClassCost | ||
|
||
# Build Different Trees --------------------------------------------------- | ||
|
||
treeeNow = new_SingleTreee(datX = datX, | ||
response = response, | ||
ldaType = ldaType, | ||
nodeModel = nodeModel, | ||
missingMethod = missingMethod, | ||
prior = prior, | ||
maxTreeLevel = maxTreeLevel, | ||
minNodeSize = minNodeSize, | ||
pThreshold = ifelse(pruneMethod == "pre", pThreshold, 0.51), | ||
verbose = verbose, | ||
kSample = kSample) | ||
|
||
if(pruneMethod == "pre-post"){ | ||
#> Step Ahead and prune back, 0.51 as a looser bound | ||
#> Discard due to unnecessary | ||
treeeNow <- updateAlphaInTree(treeeNow) | ||
treeeNow <- pruneTreee(treeeNow, pThreshold) | ||
treeeNow <- dropNodes(treeeNow) | ||
} | ||
pThreshold = pThreshold, | ||
prior = prior, | ||
misClassCost = misClassCost, | ||
missingMethod = missingMethod, | ||
kSample = kSample, | ||
verbose = verbose) | ||
|
||
if(verbose) cat(paste('\nThe pre-pruned LDA tree is completed. It has', length(treeeNow), 'nodes.\n')) | ||
|
||
finalTreee <- structure(list(treee = treeeNow, | ||
missingMethod = missingMethod), class = "Treee") | ||
|
||
|
||
# Pruning ----------------------------------------------------------------- | ||
|
||
if(pruneMethod == "post" & length(treeeNow) > 1){ | ||
pruningOutput <- prune(oldTreee = treeeNow, | ||
numberOfPruning = numberOfPruning, | ||
datX = datX, | ||
response = response, | ||
ldaType = ldaType, | ||
nodeModel = nodeModel, | ||
missingMethod = missingMethod, | ||
prior = prior, | ||
numberOfPruning = numberOfPruning, | ||
maxTreeLevel = maxTreeLevel, | ||
minNodeSize = minNodeSize, | ||
pThreshold = 0.51, | ||
verbose = verbose, | ||
kSample = kSample) | ||
|
||
# Add something to the finalTreee | ||
finalTreee$treee <- pruningOutput$treeeNew | ||
finalTreee$CV_Table <- pruningOutput$CV_Table | ||
pThreshold = pThreshold, | ||
prior = prior, | ||
misClassCost = misClassCost, | ||
missingMethod = missingMethod, | ||
kSample = kSample, | ||
verbose = verbose) | ||
|
||
if(verbose) cat(paste('\nThe post-pruned tree is completed. It has', length(finalTreee$treee), 'nodes.\n')) | ||
treeeNow <- pruningOutput$treeeNew | ||
attr(treeeNow, "CV_Table") <- pruningOutput$CV_Table | ||
if(verbose) cat(paste('\nThe post-pruned tree is completed. It has', length(treeeNow), 'nodes.\n')) | ||
} | ||
|
||
return(finalTreee) | ||
return(treeeNow) | ||
} |
Oops, something went wrong.